tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
//! | ALL_HEADERS       | (TDS 7.2+; always written by this encoder)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29
/// Numeric identifiers of the system stored procedures that can be called by ID.
///
/// Sending one of these IDs in an RPC request selects the procedure without
/// spelling out its name. Each variant's discriminant is the 16-bit value
/// that is written to the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u16)]
pub enum ProcId {
    /// `sp_cursor` (ID 1).
    Cursor = 1,
    /// `sp_cursoropen` (ID 2).
    CursorOpen = 2,
    /// `sp_cursorprepare` (ID 3).
    CursorPrepare = 3,
    /// `sp_cursorexecute` (ID 4).
    CursorExecute = 4,
    /// `sp_cursorprepexec` (ID 5).
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` (ID 6).
    CursorUnprepare = 6,
    /// `sp_cursorfetch` (ID 7).
    CursorFetch = 7,
    /// `sp_cursoroption` (ID 8).
    CursorOption = 8,
    /// `sp_cursorclose` (ID 9).
    CursorClose = 9,
    /// `sp_executesql` (ID 10): the primary method for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` (ID 11).
    Prepare = 11,
    /// `sp_execute` (ID 12).
    Execute = 12,
    /// `sp_prepexec` (ID 13): prepare and execute in one call.
    PrepExec = 13,
    /// `sp_prepexecrpc` (ID 14).
    PrepExecRpc = 14,
    /// `sp_unprepare` (ID 15).
    Unprepare = 15,
}
68
/// RPC option flags.
///
/// The flags are packed into a 16-bit field by [`RpcOptionFlags::encode`].
/// Every flag can be set either through its public field or through the
/// builder method of the same name.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Recompile the procedure (bit `0x0001`).
    pub with_recompile: bool,
    /// No metadata in response (bit `0x0002`).
    pub no_metadata: bool,
    /// Reuse metadata from previous call (bit `0x0004`).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Bit set when [`RpcOptionFlags::with_recompile`] is enabled.
    const WITH_RECOMPILE: u16 = 0x0001;
    /// Bit set when [`RpcOptionFlags::no_metadata`] is enabled.
    const NO_METADATA: u16 = 0x0002;
    /// Bit set when [`RpcOptionFlags::reuse_metadata`] is enabled.
    const REUSE_METADATA: u16 = 0x0004;

    /// Create new empty flags (every flag cleared).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set with recompile flag.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Set the no-metadata flag.
    #[must_use]
    pub fn no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Set the reuse-metadata flag.
    #[must_use]
    pub fn reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Encode to wire format (2 bytes).
    ///
    /// Returns the bit field as a `u16`; the caller decides the byte order
    /// when writing it out.
    pub fn encode(&self) -> u16 {
        let mut flags = 0u16;
        if self.with_recompile {
            flags |= Self::WITH_RECOMPILE;
        }
        if self.no_metadata {
            flags |= Self::NO_METADATA;
        }
        if self.reuse_metadata {
            flags |= Self::REUSE_METADATA;
        }
        flags
    }
}
108
/// Status bits attached to a single RPC parameter.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// The parameter is passed by reference, i.e. it is an OUTPUT parameter.
    pub by_ref: bool,
    /// The parameter carries a default value.
    pub default: bool,
    /// The parameter is encrypted (Always Encrypted).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Build a flag set with nothing enabled.
    pub fn new() -> Self {
        Self {
            by_ref: false,
            default: false,
            encrypted: false,
        }
    }

    /// Return a copy of these flags with the by-reference (OUTPUT) bit enabled.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Pack the flags into the single status byte used on the wire.
    ///
    /// `by_ref` maps to `0x01`, `default` to `0x02` and `encrypted` to `0x08`.
    pub fn encode(&self) -> u8 {
        let by_ref_bit = u8::from(self.by_ref);
        let default_bit = u8::from(self.default) << 1;
        let encrypted_bit = u8::from(self.encrypted) << 3;
        by_ref_bit | default_bit | encrypted_bit
    }
}
148
149/// TDS type information for RPC parameters.
150#[derive(Debug, Clone)]
151pub struct TypeInfo {
152    /// Type ID.
153    pub type_id: u8,
154    /// Maximum length for variable-length types.
155    pub max_length: Option<u16>,
156    /// Precision for numeric types.
157    pub precision: Option<u8>,
158    /// Scale for numeric types.
159    pub scale: Option<u8>,
160    /// Collation for string types.
161    pub collation: Option<[u8; 5]>,
162}
163
164impl TypeInfo {
165    /// Create type info for INT.
166    pub fn int() -> Self {
167        Self {
168            type_id: 0x26, // INTNTYPE (variable-length int)
169            max_length: Some(4),
170            precision: None,
171            scale: None,
172            collation: None,
173        }
174    }
175
176    /// Create type info for BIGINT.
177    pub fn bigint() -> Self {
178        Self {
179            type_id: 0x26, // INTNTYPE
180            max_length: Some(8),
181            precision: None,
182            scale: None,
183            collation: None,
184        }
185    }
186
187    /// Create type info for SMALLINT.
188    pub fn smallint() -> Self {
189        Self {
190            type_id: 0x26, // INTNTYPE
191            max_length: Some(2),
192            precision: None,
193            scale: None,
194            collation: None,
195        }
196    }
197
198    /// Create type info for TINYINT.
199    pub fn tinyint() -> Self {
200        Self {
201            type_id: 0x26, // INTNTYPE
202            max_length: Some(1),
203            precision: None,
204            scale: None,
205            collation: None,
206        }
207    }
208
209    /// Create type info for BIT.
210    pub fn bit() -> Self {
211        Self {
212            type_id: 0x68, // BITNTYPE
213            max_length: Some(1),
214            precision: None,
215            scale: None,
216            collation: None,
217        }
218    }
219
220    /// Create type info for FLOAT.
221    pub fn float() -> Self {
222        Self {
223            type_id: 0x6D, // FLTNTYPE
224            max_length: Some(8),
225            precision: None,
226            scale: None,
227            collation: None,
228        }
229    }
230
231    /// Create type info for REAL.
232    pub fn real() -> Self {
233        Self {
234            type_id: 0x6D, // FLTNTYPE
235            max_length: Some(4),
236            precision: None,
237            scale: None,
238            collation: None,
239        }
240    }
241
242    /// Create type info for NVARCHAR with max length.
243    pub fn nvarchar(max_len: u16) -> Self {
244        Self {
245            type_id: 0xE7,                 // NVARCHARTYPE
246            max_length: Some(max_len * 2), // UTF-16, so double the char count
247            precision: None,
248            scale: None,
249            // Default collation (Latin1_General_CI_AS equivalent)
250            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
251        }
252    }
253
254    /// Create type info for NVARCHAR(MAX).
255    pub fn nvarchar_max() -> Self {
256        Self {
257            type_id: 0xE7,            // NVARCHARTYPE
258            max_length: Some(0xFFFF), // MAX indicator
259            precision: None,
260            scale: None,
261            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
262        }
263    }
264
265    /// Create type info for VARBINARY with max length.
266    pub fn varbinary(max_len: u16) -> Self {
267        Self {
268            type_id: 0xA5, // BIGVARBINTYPE
269            max_length: Some(max_len),
270            precision: None,
271            scale: None,
272            collation: None,
273        }
274    }
275
276    /// Create type info for UNIQUEIDENTIFIER.
277    pub fn uniqueidentifier() -> Self {
278        Self {
279            type_id: 0x24, // GUIDTYPE
280            max_length: Some(16),
281            precision: None,
282            scale: None,
283            collation: None,
284        }
285    }
286
287    /// Create type info for DATE.
288    pub fn date() -> Self {
289        Self {
290            type_id: 0x28, // DATETYPE
291            max_length: None,
292            precision: None,
293            scale: None,
294            collation: None,
295        }
296    }
297
298    /// Create type info for DATETIME2.
299    pub fn datetime2(scale: u8) -> Self {
300        Self {
301            type_id: 0x2A, // DATETIME2TYPE
302            max_length: None,
303            precision: None,
304            scale: Some(scale),
305            collation: None,
306        }
307    }
308
309    /// Create type info for DECIMAL.
310    pub fn decimal(precision: u8, scale: u8) -> Self {
311        Self {
312            type_id: 0x6C,        // DECIMALNTYPE
313            max_length: Some(17), // Max decimal size
314            precision: Some(precision),
315            scale: Some(scale),
316            collation: None,
317        }
318    }
319
320    /// Encode type info to buffer.
321    pub fn encode(&self, buf: &mut BytesMut) {
322        buf.put_u8(self.type_id);
323
324        // Variable-length types need max length
325        match self.type_id {
326            0x26 | 0x68 | 0x6D => {
327                // INTNTYPE, BITNTYPE, FLTNTYPE
328                if let Some(len) = self.max_length {
329                    buf.put_u8(len as u8);
330                }
331            }
332            0xE7 | 0xA5 | 0xEF => {
333                // NVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
334                if let Some(len) = self.max_length {
335                    buf.put_u16_le(len);
336                }
337                // Collation for string types
338                if let Some(collation) = self.collation {
339                    buf.put_slice(&collation);
340                }
341            }
342            0x24 => {
343                // GUIDTYPE
344                if let Some(len) = self.max_length {
345                    buf.put_u8(len as u8);
346                }
347            }
348            0x29..=0x2B => {
349                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
350                if let Some(scale) = self.scale {
351                    buf.put_u8(scale);
352                }
353            }
354            0x6C | 0x6A => {
355                // DECIMALNTYPE, NUMERICNTYPE
356                if let Some(len) = self.max_length {
357                    buf.put_u8(len as u8);
358                }
359                if let Some(precision) = self.precision {
360                    buf.put_u8(precision);
361                }
362                if let Some(scale) = self.scale {
363                    buf.put_u8(scale);
364                }
365            }
366            _ => {}
367        }
368    }
369}
370
371/// An RPC parameter.
372#[derive(Debug, Clone)]
373pub struct RpcParam {
374    /// Parameter name (can be empty for positional params).
375    pub name: String,
376    /// Status flags.
377    pub flags: ParamFlags,
378    /// Type information.
379    pub type_info: TypeInfo,
380    /// Parameter value (raw bytes).
381    pub value: Option<Bytes>,
382}
383
384impl RpcParam {
385    /// Create a new parameter with a value.
386    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
387        Self {
388            name: name.into(),
389            flags: ParamFlags::default(),
390            type_info,
391            value: Some(value),
392        }
393    }
394
395    /// Create a NULL parameter.
396    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
397        Self {
398            name: name.into(),
399            flags: ParamFlags::default(),
400            type_info,
401            value: None,
402        }
403    }
404
405    /// Create an INT parameter.
406    pub fn int(name: impl Into<String>, value: i32) -> Self {
407        let mut buf = BytesMut::with_capacity(4);
408        buf.put_i32_le(value);
409        Self::new(name, TypeInfo::int(), buf.freeze())
410    }
411
412    /// Create a BIGINT parameter.
413    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
414        let mut buf = BytesMut::with_capacity(8);
415        buf.put_i64_le(value);
416        Self::new(name, TypeInfo::bigint(), buf.freeze())
417    }
418
419    /// Create an NVARCHAR parameter.
420    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
421        let mut buf = BytesMut::new();
422        // Encode as UTF-16LE
423        for code_unit in value.encode_utf16() {
424            buf.put_u16_le(code_unit);
425        }
426        let char_len = value.chars().count();
427        let type_info = if char_len > 4000 {
428            TypeInfo::nvarchar_max()
429        } else {
430            TypeInfo::nvarchar(char_len.max(1) as u16)
431        };
432        Self::new(name, type_info, buf.freeze())
433    }
434
435    /// Mark as output parameter.
436    #[must_use]
437    pub fn as_output(mut self) -> Self {
438        self.flags = self.flags.output();
439        self
440    }
441
442    /// Encode the parameter to buffer.
443    pub fn encode(&self, buf: &mut BytesMut) {
444        // Parameter name (B_VARCHAR - length-prefixed)
445        let name_len = self.name.encode_utf16().count() as u8;
446        buf.put_u8(name_len);
447        if name_len > 0 {
448            for code_unit in self.name.encode_utf16() {
449                buf.put_u16_le(code_unit);
450            }
451        }
452
453        // Status flags
454        buf.put_u8(self.flags.encode());
455
456        // Type info
457        self.type_info.encode(buf);
458
459        // Value
460        if let Some(ref value) = self.value {
461            // Length prefix based on type
462            match self.type_info.type_id {
463                0x26 => {
464                    // INTNTYPE
465                    buf.put_u8(value.len() as u8);
466                    buf.put_slice(value);
467                }
468                0x68 | 0x6D => {
469                    // BITNTYPE, FLTNTYPE
470                    buf.put_u8(value.len() as u8);
471                    buf.put_slice(value);
472                }
473                0xE7 | 0xA5 => {
474                    // NVARCHARTYPE, BIGVARBINTYPE
475                    if self.type_info.max_length == Some(0xFFFF) {
476                        // MAX type - use PLP format
477                        // For simplicity, send as single chunk
478                        let total_len = value.len() as u64;
479                        buf.put_u64_le(total_len);
480                        buf.put_u32_le(value.len() as u32);
481                        buf.put_slice(value);
482                        buf.put_u32_le(0); // Terminator
483                    } else {
484                        buf.put_u16_le(value.len() as u16);
485                        buf.put_slice(value);
486                    }
487                }
488                0x24 => {
489                    // GUIDTYPE
490                    buf.put_u8(value.len() as u8);
491                    buf.put_slice(value);
492                }
493                0x28 => {
494                    // DATETYPE (fixed 3 bytes)
495                    buf.put_slice(value);
496                }
497                0x2A => {
498                    // DATETIME2TYPE
499                    buf.put_u8(value.len() as u8);
500                    buf.put_slice(value);
501                }
502                0x6C => {
503                    // DECIMALNTYPE
504                    buf.put_u8(value.len() as u8);
505                    buf.put_slice(value);
506                }
507                _ => {
508                    // Generic: assume length-prefixed
509                    buf.put_u8(value.len() as u8);
510                    buf.put_slice(value);
511                }
512            }
513        } else {
514            // NULL value
515            match self.type_info.type_id {
516                0xE7 | 0xA5 => {
517                    // Variable-length types use 0xFFFF for NULL
518                    if self.type_info.max_length == Some(0xFFFF) {
519                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
520                    } else {
521                        buf.put_u16_le(0xFFFF);
522                    }
523                }
524                _ => {
525                    buf.put_u8(0); // Zero-length for NULL
526                }
527            }
528        }
529    }
530}
531
532/// RPC request builder.
533#[derive(Debug, Clone)]
534pub struct RpcRequest {
535    /// Procedure name (if using named procedure).
536    proc_name: Option<String>,
537    /// Procedure ID (if using well-known procedure).
538    proc_id: Option<ProcId>,
539    /// Option flags.
540    options: RpcOptionFlags,
541    /// Parameters.
542    params: Vec<RpcParam>,
543}
544
545impl RpcRequest {
546    /// Create a new RPC request for a named procedure.
547    pub fn named(proc_name: impl Into<String>) -> Self {
548        Self {
549            proc_name: Some(proc_name.into()),
550            proc_id: None,
551            options: RpcOptionFlags::default(),
552            params: Vec::new(),
553        }
554    }
555
556    /// Create a new RPC request for a well-known procedure.
557    pub fn by_id(proc_id: ProcId) -> Self {
558        Self {
559            proc_name: None,
560            proc_id: Some(proc_id),
561            options: RpcOptionFlags::default(),
562            params: Vec::new(),
563        }
564    }
565
566    /// Create an sp_executesql request.
567    ///
568    /// This is the primary method for parameterized queries.
569    ///
570    /// # Example
571    ///
572    /// ```
573    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
574    ///
575    /// let rpc = RpcRequest::execute_sql(
576    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
577    ///     vec![
578    ///         RpcParam::int("@p1", 42),
579    ///         RpcParam::nvarchar("@p2", "Alice"),
580    ///     ],
581    /// );
582    /// ```
583    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
584        let mut request = Self::by_id(ProcId::ExecuteSql);
585
586        // First parameter: the SQL statement (NVARCHAR(MAX))
587        request.params.push(RpcParam::nvarchar("", sql));
588
589        // Second parameter: parameter declarations
590        if !params.is_empty() {
591            let declarations = Self::build_param_declarations(&params);
592            request.params.push(RpcParam::nvarchar("", &declarations));
593        }
594
595        // Add the actual parameters
596        request.params.extend(params);
597
598        request
599    }
600
601    /// Build parameter declaration string for sp_executesql.
602    fn build_param_declarations(params: &[RpcParam]) -> String {
603        params
604            .iter()
605            .map(|p| {
606                let name = if p.name.starts_with('@') {
607                    p.name.clone()
608                } else if p.name.is_empty() {
609                    // Generate positional name
610                    format!(
611                        "@p{}",
612                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
613                    )
614                } else {
615                    format!("@{}", p.name)
616                };
617
618                let type_name: String = match p.type_info.type_id {
619                    0x26 => match p.type_info.max_length {
620                        Some(1) => "tinyint".to_string(),
621                        Some(2) => "smallint".to_string(),
622                        Some(4) => "int".to_string(),
623                        Some(8) => "bigint".to_string(),
624                        _ => "int".to_string(),
625                    },
626                    0x68 => "bit".to_string(),
627                    0x6D => match p.type_info.max_length {
628                        Some(4) => "real".to_string(),
629                        _ => "float".to_string(),
630                    },
631                    0xE7 => {
632                        if p.type_info.max_length == Some(0xFFFF) {
633                            "nvarchar(max)".to_string()
634                        } else {
635                            let len = p.type_info.max_length.unwrap_or(4000) / 2;
636                            format!("nvarchar({})", len)
637                        }
638                    }
639                    0xA5 => {
640                        if p.type_info.max_length == Some(0xFFFF) {
641                            "varbinary(max)".to_string()
642                        } else {
643                            let len = p.type_info.max_length.unwrap_or(8000);
644                            format!("varbinary({})", len)
645                        }
646                    }
647                    0x24 => "uniqueidentifier".to_string(),
648                    0x28 => "date".to_string(),
649                    0x2A => {
650                        let scale = p.type_info.scale.unwrap_or(7);
651                        format!("datetime2({})", scale)
652                    }
653                    0x6C => {
654                        let precision = p.type_info.precision.unwrap_or(18);
655                        let scale = p.type_info.scale.unwrap_or(0);
656                        format!("decimal({}, {})", precision, scale)
657                    }
658                    _ => "sql_variant".to_string(),
659                };
660
661                format!("{} {}", name, type_name)
662            })
663            .collect::<Vec<_>>()
664            .join(", ")
665    }
666
667    /// Create an sp_prepare request.
668    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
669        let mut request = Self::by_id(ProcId::Prepare);
670
671        // OUT: handle (INT)
672        request
673            .params
674            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
675
676        // Param declarations
677        let declarations = Self::build_param_declarations(params);
678        request
679            .params
680            .push(RpcParam::nvarchar("@params", &declarations));
681
682        // SQL statement
683        request.params.push(RpcParam::nvarchar("@stmt", sql));
684
685        // Options (1 = WITH RECOMPILE)
686        request.params.push(RpcParam::int("@options", 1));
687
688        request
689    }
690
691    /// Create an sp_execute request.
692    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
693        let mut request = Self::by_id(ProcId::Execute);
694
695        // Handle from sp_prepare
696        request.params.push(RpcParam::int("@handle", handle));
697
698        // Add parameters
699        request.params.extend(params);
700
701        request
702    }
703
704    /// Create an sp_unprepare request.
705    pub fn unprepare(handle: i32) -> Self {
706        let mut request = Self::by_id(ProcId::Unprepare);
707        request.params.push(RpcParam::int("@handle", handle));
708        request
709    }
710
711    /// Set option flags.
712    #[must_use]
713    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
714        self.options = options;
715        self
716    }
717
718    /// Add a parameter.
719    #[must_use]
720    pub fn param(mut self, param: RpcParam) -> Self {
721        self.params.push(param);
722        self
723    }
724
725    /// Encode the RPC request to bytes (auto-commit mode).
726    ///
727    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
728    #[must_use]
729    pub fn encode(&self) -> Bytes {
730        self.encode_with_transaction(0)
731    }
732
733    /// Encode the RPC request with a transaction descriptor.
734    ///
735    /// Per MS-TDS spec, when executing within an explicit transaction:
736    /// - The `transaction_descriptor` MUST be the value returned by the server
737    ///   in the BeginTransaction EnvChange token.
738    /// - For auto-commit mode (no explicit transaction), use 0.
739    ///
740    /// # Arguments
741    ///
742    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
743    ///   or 0 for auto-commit mode.
744    #[must_use]
745    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
746        let mut buf = BytesMut::with_capacity(256);
747
748        // ALL_HEADERS - TDS 7.2+ requires this section
749        // Total length placeholder (will be filled in)
750        let all_headers_start = buf.len();
751        buf.put_u32_le(0); // Total length placeholder
752
753        // Transaction descriptor header (required for RPC)
754        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
755        buf.put_u32_le(18); // Header length
756        buf.put_u16_le(0x0002); // Header type: transaction descriptor
757        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
758        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
759
760        // Fill in ALL_HEADERS total length
761        let all_headers_len = buf.len() - all_headers_start;
762        let len_bytes = (all_headers_len as u32).to_le_bytes();
763        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
764
765        // Procedure name or ID
766        if let Some(proc_id) = self.proc_id {
767            // Use PROCID format
768            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
769            buf.put_u16_le(proc_id as u16);
770        } else if let Some(ref proc_name) = self.proc_name {
771            // Use procedure name
772            let name_len = proc_name.encode_utf16().count() as u16;
773            buf.put_u16_le(name_len);
774            write_utf16_string(&mut buf, proc_name);
775        }
776
777        // Option flags
778        buf.put_u16_le(self.options.encode());
779
780        // Parameters
781        for param in &self.params {
782            param.encode(&mut buf);
783        }
784
785        buf.freeze()
786    }
787}
788
// Unit tests for the RPC encoder. Besides checking how requests are
// assembled, they pin the exact bytes of the parts of the wire format that
// are fully determined by this module.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        assert_eq!(RpcOptionFlags::new().encode(), 0x0000);

        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);

        let all = RpcOptionFlags {
            with_recompile: true,
            no_metadata: true,
            reuse_metadata: true,
        };
        assert_eq!(all.encode(), 0x0007);
    }

    #[test]
    fn test_param_flags_encode() {
        assert_eq!(ParamFlags::new().encode(), 0x00);

        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);

        let encrypted = ParamFlags {
            encrypted: true,
            ..ParamFlags::new()
        };
        assert_eq!(encrypted.encode(), 0x08);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert_eq!(param.type_info.max_length, Some(4));
        // Value is the little-endian i32
        assert_eq!(param.value.as_deref(), Some(&[42u8, 0, 0, 0][..]));
    }

    #[test]
    fn test_int_param_wire_format() {
        let mut buf = BytesMut::new();
        RpcParam::int("@p1", 42).encode(&mut buf);

        let expected: [u8; 15] = [
            3, // name length in UTF-16 code units
            b'@', 0, b'p', 0, b'1', 0, // "@p1" as UTF-16LE
            0x00, // status flags
            0x26, 4, // INTNTYPE, max length 4
            4, 42, 0, 0, 0, // value length + little-endian value
        ];
        assert_eq!(&buf[..], &expected[..]);
    }

    #[test]
    fn test_null_output_param_wire_format() {
        let mut buf = BytesMut::new();
        RpcParam::null("", TypeInfo::int())
            .as_output()
            .encode(&mut buf);

        // empty name, by-ref flag, INTNTYPE(4), zero-length (NULL) value
        assert_eq!(&buf[..], &[0u8, 0x01, 0x26, 4, 0][..]);
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // Declared size is in bytes: 5 characters * 2
        assert_eq!(param.type_info.max_length, Some(10));
        // UTF-16 encoded "Alice" = 10 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // SQL statement + param declarations + actual params
        assert_eq!(rpc.params.len(), 3);
        assert_eq!(rpc.params[2].name, "@p1");
    }

    #[test]
    fn test_execute_sql_without_params_omits_declarations() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        // Only the statement itself
        assert_eq!(rpc.params.len(), 1);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert_eq!(decls, "@p1 int, @name nvarchar(5)");
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());

        // ALL_HEADERS: total length 22, one 18-byte transaction descriptor header
        assert_eq!(&encoded[0..4], &22u32.to_le_bytes()[..]);
        assert_eq!(&encoded[4..8], &18u32.to_le_bytes()[..]);
        assert_eq!(&encoded[8..10], &[0x02u8, 0x00][..]);
        // Auto-commit: zero descriptor, one outstanding request
        assert_eq!(&encoded[10..18], &[0u8; 8][..]);
        assert_eq!(&encoded[18..22], &1u32.to_le_bytes()[..]);
        // PROCID marker + sp_executesql, then empty option flags
        assert_eq!(&encoded[22..26], &[0xFFu8, 0xFF, 0x0A, 0x00][..]);
        assert_eq!(&encoded[26..28], &[0x00u8, 0x00][..]);
    }

    #[test]
    fn test_encode_with_transaction_descriptor() {
        let descriptor = 0x1122_3344_5566_7788u64;
        let encoded = RpcRequest::unprepare(1).encode_with_transaction(descriptor);

        assert_eq!(&encoded[10..18], &descriptor.to_le_bytes()[..]);
        // Everything outside the descriptor matches the auto-commit encoding
        let auto_commit = RpcRequest::unprepare(1).encode();
        assert_eq!(encoded.len(), auto_commit.len());
        assert_eq!(&encoded[..10], &auto_commit[..10]);
        assert_eq!(&encoded[18..], &auto_commit[18..]);
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // handle (output), params, stmt, options
        assert_eq!(rpc.params.len(), 4);
        assert!(rpc.params[0].flags.by_ref); // handle is OUTPUT
        assert!(rpc.params[0].value.is_none()); // and NULL on input
        assert_eq!(rpc.params[1].name, "@params");
        assert_eq!(rpc.params[2].name, "@stmt");
        assert_eq!(rpc.params[3].name, "@options");
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        assert_eq!(rpc.params.len(), 2); // handle + param
        assert_eq!(
            rpc.params[0].value.as_deref(),
            Some(&123i32.to_le_bytes()[..])
        );
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        assert_eq!(rpc.params.len(), 1); // just the handle
    }
}