Skip to main content

tds_protocol/
rpc.rs

1//! RPC (Remote Procedure Call) request encoding.
2//!
3//! This module provides encoding for RPC requests (packet type 0x03).
4//! RPC is used for calling stored procedures and sp_executesql for parameterized queries.
5//!
6//! ## sp_executesql
7//!
8//! The primary use case is `sp_executesql` for parameterized queries, which prevents
9//! SQL injection and enables query plan caching.
10//!
11//! ## Wire Format
12//!
13//! ```text
14//! RPC Request:
15//! +-------------------+
16//! | ALL_HEADERS       | (TDS 7.2+, optional)
17//! +-------------------+
18//! | ProcName/ProcID   | (procedure identifier)
19//! +-------------------+
20//! | Option Flags      | (2 bytes)
21//! +-------------------+
22//! | Parameters        | (repeated)
23//! +-------------------+
24//! ```
25
26use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::crypto::EncryptionTypeWire;
30use crate::prelude::*;
31use crate::token::Collation;
32
/// System stored procedures that an RPC request can name by numeric ID.
///
/// Instead of sending a procedure name, an RPC request may carry one of these
/// IDs and SQL Server resolves it to the corresponding system procedure. The
/// discriminant of each variant is the ID that goes on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor` (ID 1).
    Cursor = 1,
    /// `sp_cursoropen` (ID 2).
    CursorOpen = 2,
    /// `sp_cursorprepare` (ID 3).
    CursorPrepare = 3,
    /// `sp_cursorexecute` (ID 4).
    CursorExecute = 4,
    /// `sp_cursorprepexec` (ID 5).
    CursorPrepExec = 5,
    /// `sp_cursorunprepare` (ID 6).
    CursorUnprepare = 6,
    /// `sp_cursorfetch` (ID 7).
    CursorFetch = 7,
    /// `sp_cursoroption` (ID 8).
    CursorOption = 8,
    /// `sp_cursorclose` (ID 9).
    CursorClose = 9,
    /// `sp_executesql` (ID 10) - the primary route for parameterized queries.
    ExecuteSql = 10,
    /// `sp_prepare` (ID 11).
    Prepare = 11,
    /// `sp_execute` (ID 12).
    Execute = 12,
    /// `sp_prepexec` (ID 13) - prepare and execute in a single call.
    PrepExec = 13,
    /// `sp_prepexecrpc` (ID 14).
    PrepExecRpc = 14,
    /// `sp_unprepare` (ID 15).
    Unprepare = 15,
}
72
/// RPC option flags.
///
/// Each field maps to one bit of the 2-byte `OptionFlags` word that
/// [`RpcOptionFlags::encode`] produces.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RpcOptionFlags {
    /// Recompile the procedure (`fWithRecomp`, bit 0).
    pub with_recompile: bool,
    /// No metadata in response (`fNoMetaData`, bit 1).
    pub no_metadata: bool,
    /// Reuse metadata from previous call (`fReuseMetaData`, bit 2).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Wire bit for [`RpcOptionFlags::with_recompile`].
    const WITH_RECOMPILE_BIT: u16 = 0x0001;
    /// Wire bit for [`RpcOptionFlags::no_metadata`].
    const NO_METADATA_BIT: u16 = 0x0002;
    /// Wire bit for [`RpcOptionFlags::reuse_metadata`].
    const REUSE_METADATA_BIT: u16 = 0x0004;

    /// Create new empty flags (no bits set).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the with-recompile flag.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Set the no-metadata flag.
    ///
    /// Builder counterpart of the public `no_metadata` field, so every flag
    /// can be set fluently rather than only `with_recompile`.
    #[must_use]
    pub fn no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Set the reuse-metadata flag.
    #[must_use]
    pub fn reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Encode to wire format (2 bytes, written little-endian by the caller).
    pub fn encode(&self) -> u16 {
        let mut flags = 0u16;
        if self.with_recompile {
            flags |= Self::WITH_RECOMPILE_BIT;
        }
        if self.no_metadata {
            flags |= Self::NO_METADATA_BIT;
        }
        if self.reuse_metadata {
            flags |= Self::REUSE_METADATA_BIT;
        }
        flags
    }
}
112
/// Status byte sent with each RPC parameter.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// OUTPUT parameter: passed by reference (bit 0, `0x01`).
    pub by_ref: bool,
    /// Parameter has a default value (bit 1, `0x02`).
    pub default: bool,
    /// Value is Always Encrypted ciphertext (bit 3, `0x08`).
    pub encrypted: bool,
}

impl ParamFlags {
    /// Flags with every bit cleared.
    pub fn new() -> Self {
        Self::default()
    }

    /// Return these flags with the OUTPUT (`by_ref`) bit set.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Pack the flags into the single status byte used on the wire.
    pub fn encode(&self) -> u8 {
        let by_ref_bit = u8::from(self.by_ref);
        let default_bit = u8::from(self.default) << 1;
        let encrypted_bit = u8::from(self.encrypted) << 3;
        by_ref_bit | default_bit | encrypted_bit
    }
}
152
/// Description of a parameter's type, written as the `TYPE_INFO` of an RPC
/// parameter.
///
/// Only the fields relevant to `type_id` are populated; the other optional
/// fields stay `None`.
#[derive(Debug, Clone)]
pub struct TypeInfo {
    /// TDS type byte (for example `0x26` INTN, `0xE7` NVARCHAR, `0xF3` TVP).
    pub type_id: u8,
    /// Declared maximum length in bytes; `0xFFFF` marks a `(MAX)` type.
    pub max_length: Option<u16>,
    /// Total number of digits, for numeric types.
    pub precision: Option<u8>,
    /// Digits after the decimal point for numeric types, or fractional-second
    /// precision for TIME / DATETIME2 / DATETIMEOFFSET.
    pub scale: Option<u8>,
    /// Five-byte collation, for character types.
    pub collation: Option<[u8; 5]>,
    /// Table type name (e.g. `"dbo.IntIdList"`) when this describes a
    /// Table-Valued Parameter.
    pub tvp_type_name: Option<String>,
}
169
170impl TypeInfo {
171    /// Create type info for INT.
172    pub fn int() -> Self {
173        Self {
174            type_id: 0x26, // INTNTYPE (variable-length int)
175            max_length: Some(4),
176            precision: None,
177            scale: None,
178            collation: None,
179            tvp_type_name: None,
180        }
181    }
182
183    /// Create type info for BIGINT.
184    pub fn bigint() -> Self {
185        Self {
186            type_id: 0x26, // INTNTYPE
187            max_length: Some(8),
188            precision: None,
189            scale: None,
190            collation: None,
191            tvp_type_name: None,
192        }
193    }
194
195    /// Create type info for SMALLINT.
196    pub fn smallint() -> Self {
197        Self {
198            type_id: 0x26, // INTNTYPE
199            max_length: Some(2),
200            precision: None,
201            scale: None,
202            collation: None,
203            tvp_type_name: None,
204        }
205    }
206
207    /// Create type info for TINYINT.
208    pub fn tinyint() -> Self {
209        Self {
210            type_id: 0x26, // INTNTYPE
211            max_length: Some(1),
212            precision: None,
213            scale: None,
214            collation: None,
215            tvp_type_name: None,
216        }
217    }
218
219    /// Create type info for BIT.
220    pub fn bit() -> Self {
221        Self {
222            type_id: 0x68, // BITNTYPE
223            max_length: Some(1),
224            precision: None,
225            scale: None,
226            collation: None,
227            tvp_type_name: None,
228        }
229    }
230
231    /// Create type info for FLOAT.
232    pub fn float() -> Self {
233        Self {
234            type_id: 0x6D, // FLTNTYPE
235            max_length: Some(8),
236            precision: None,
237            scale: None,
238            collation: None,
239            tvp_type_name: None,
240        }
241    }
242
243    /// Create type info for REAL.
244    pub fn real() -> Self {
245        Self {
246            type_id: 0x6D, // FLTNTYPE
247            max_length: Some(4),
248            precision: None,
249            scale: None,
250            collation: None,
251            tvp_type_name: None,
252        }
253    }
254
255    /// Create type info for NVARCHAR with max length.
256    pub fn nvarchar(max_len: u16) -> Self {
257        Self {
258            type_id: 0xE7,                 // NVARCHARTYPE
259            max_length: Some(max_len * 2), // UTF-16, so double the char count
260            precision: None,
261            scale: None,
262            // Default collation (Latin1_General_CI_AS equivalent)
263            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
264            tvp_type_name: None,
265        }
266    }
267
268    /// Create type info for NVARCHAR(MAX).
269    pub fn nvarchar_max() -> Self {
270        Self {
271            type_id: 0xE7,            // NVARCHARTYPE
272            max_length: Some(0xFFFF), // MAX indicator
273            precision: None,
274            scale: None,
275            collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
276            tvp_type_name: None,
277        }
278    }
279
280    /// Default collation bytes: Latin1_General_CI_AS (LCID 0x0409, sort ID 0x34).
281    const DEFAULT_COLLATION: [u8; 5] = [0x09, 0x04, 0xD0, 0x00, 0x34];
282
283    /// Create type info for VARCHAR with max length (in bytes).
284    pub fn varchar(max_len: u16) -> Self {
285        Self::varchar_with_collation(max_len, Self::DEFAULT_COLLATION)
286    }
287
288    /// Create type info for VARCHAR with max length and explicit collation.
289    pub fn varchar_with_collation(max_len: u16, collation: [u8; 5]) -> Self {
290        Self {
291            type_id: 0xA7, // BIGVARCHARTYPE
292            max_length: Some(max_len),
293            precision: None,
294            scale: None,
295            collation: Some(collation),
296            tvp_type_name: None,
297        }
298    }
299
300    /// Create type info for VARCHAR(MAX).
301    pub fn varchar_max() -> Self {
302        Self::varchar_max_with_collation(Self::DEFAULT_COLLATION)
303    }
304
305    /// Create type info for VARCHAR(MAX) with explicit collation.
306    pub fn varchar_max_with_collation(collation: [u8; 5]) -> Self {
307        Self {
308            type_id: 0xA7,            // BIGVARCHARTYPE
309            max_length: Some(0xFFFF), // MAX indicator
310            precision: None,
311            scale: None,
312            collation: Some(collation),
313            tvp_type_name: None,
314        }
315    }
316
317    /// Create type info for VARBINARY with max length.
318    pub fn varbinary(max_len: u16) -> Self {
319        Self {
320            type_id: 0xA5, // BIGVARBINTYPE
321            max_length: Some(max_len),
322            precision: None,
323            scale: None,
324            collation: None,
325            tvp_type_name: None,
326        }
327    }
328
329    /// Create type info for VARBINARY(MAX).
330    pub fn varbinary_max() -> Self {
331        Self {
332            type_id: 0xA5,            // BIGVARBINTYPE
333            max_length: Some(0xFFFF), // MAX indicator β€” triggers PLP encoding
334            precision: None,
335            scale: None,
336            collation: None,
337            tvp_type_name: None,
338        }
339    }
340
341    /// Create type info for UNIQUEIDENTIFIER.
342    pub fn uniqueidentifier() -> Self {
343        Self {
344            type_id: 0x24, // GUIDTYPE
345            max_length: Some(16),
346            precision: None,
347            scale: None,
348            collation: None,
349            tvp_type_name: None,
350        }
351    }
352
353    /// Create type info for UNIQUEIDENTIFIER.
354    pub fn uuid() -> Self {
355        Self {
356            type_id: 0x24, // GUIDTYPE
357            max_length: Some(16),
358            precision: None,
359            scale: None,
360            collation: None,
361            tvp_type_name: None,
362        }
363    }
364
365    /// Create type info for DATE.
366    pub fn date() -> Self {
367        Self {
368            type_id: 0x28, // DATETYPE
369            max_length: None,
370            precision: None,
371            scale: None,
372            collation: None,
373            tvp_type_name: None,
374        }
375    }
376
377    /// Create type info for TIME.
378    pub fn time(scale: u8) -> Self {
379        Self {
380            type_id: 0x29, // TIMETYPE
381            max_length: None,
382            precision: None,
383            scale: Some(scale),
384            collation: None,
385            tvp_type_name: None,
386        }
387    }
388
389    /// Create type info for DATETIME2.
390    pub fn datetime2(scale: u8) -> Self {
391        Self {
392            type_id: 0x2A, // DATETIME2TYPE
393            max_length: None,
394            precision: None,
395            scale: Some(scale),
396            collation: None,
397            tvp_type_name: None,
398        }
399    }
400
401    /// Create type info for DATETIMEOFFSET.
402    pub fn datetimeoffset(scale: u8) -> Self {
403        Self {
404            type_id: 0x2B, // DATETIMEOFFSETTYPE
405            max_length: None,
406            precision: None,
407            scale: Some(scale),
408            collation: None,
409            tvp_type_name: None,
410        }
411    }
412
413    /// Create type info for DECIMAL.
414    pub fn decimal(precision: u8, scale: u8) -> Self {
415        Self {
416            type_id: 0x6C,        // DECIMALNTYPE
417            max_length: Some(17), // Max decimal size
418            precision: Some(precision),
419            scale: Some(scale),
420            collation: None,
421            tvp_type_name: None,
422        }
423    }
424
425    /// Create type info for MONEY (8-byte scaled integer via MONEYN / 0x6E).
426    pub fn money() -> Self {
427        Self {
428            type_id: 0x6E, // MONEYNTYPE
429            max_length: Some(8),
430            precision: None,
431            scale: None,
432            collation: None,
433            tvp_type_name: None,
434        }
435    }
436
437    /// Create type info for SMALLMONEY (4-byte scaled integer via MONEYN / 0x6E).
438    pub fn smallmoney() -> Self {
439        Self {
440            type_id: 0x6E, // MONEYNTYPE
441            max_length: Some(4),
442            precision: None,
443            scale: None,
444            collation: None,
445            tvp_type_name: None,
446        }
447    }
448
449    /// Create type info for SMALLDATETIME (4-byte days+minutes via DATETIMEN / 0x6F).
450    pub fn smalldatetime() -> Self {
451        Self {
452            type_id: 0x6F, // DATETIMENTYPE
453            max_length: Some(4),
454            precision: None,
455            scale: None,
456            collation: None,
457            tvp_type_name: None,
458        }
459    }
460
461    /// Create type info for a Table-Valued Parameter.
462    ///
463    /// # Arguments
464    /// * `type_name` - The fully qualified table type name (e.g., "dbo.IntIdList")
465    pub fn tvp(type_name: impl Into<String>) -> Self {
466        Self {
467            type_id: 0xF3, // TVP type
468            max_length: None,
469            precision: None,
470            scale: None,
471            collation: None,
472            tvp_type_name: Some(type_name.into()),
473        }
474    }
475
476    /// Encode type info to buffer.
477    pub fn encode(&self, buf: &mut BytesMut) {
478        // TVP (0xF3) has type_id embedded in the value data itself
479        // (written by TvpEncoder::encode_metadata), so don't write it here
480        if self.type_id != 0xF3 {
481            buf.put_u8(self.type_id);
482        }
483
484        // Variable-length types need max length
485        match self.type_id {
486            0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
487                // INTNTYPE, BITNTYPE, FLTNTYPE, MONEYNTYPE, DATETIMENTYPE
488                if let Some(len) = self.max_length {
489                    buf.put_u8(len as u8);
490                }
491            }
492            0xE7 | 0xA7 | 0xA5 | 0xEF => {
493                // NVARCHARTYPE, BIGVARCHARTYPE, BIGVARBINTYPE, NCHARTYPE
494                if let Some(len) = self.max_length {
495                    buf.put_u16_le(len);
496                }
497                // Collation for string types
498                if let Some(collation) = self.collation {
499                    buf.put_slice(&collation);
500                }
501            }
502            0x24 => {
503                // GUIDTYPE
504                if let Some(len) = self.max_length {
505                    buf.put_u8(len as u8);
506                }
507            }
508            0x29..=0x2B => {
509                // DATETIME2TYPE, TIMETYPE, DATETIMEOFFSETTYPE
510                if let Some(scale) = self.scale {
511                    buf.put_u8(scale);
512                }
513            }
514            0x6C | 0x6A => {
515                // DECIMALNTYPE, NUMERICNTYPE
516                if let Some(len) = self.max_length {
517                    buf.put_u8(len as u8);
518                }
519                if let Some(precision) = self.precision {
520                    buf.put_u8(precision);
521                }
522                if let Some(scale) = self.scale {
523                    buf.put_u8(scale);
524                }
525            }
526            _ => {}
527        }
528    }
529}
530
531/// Always Encrypted cipher metadata, written after an encrypted parameter's
532/// ciphertext value (MS-TDS 2.2.6.6 `CryptoMetadata`).
533///
534/// It tells the server how to validate and route the encrypted value: the
535/// plaintext column type (`BaseTypeInfo`), the AEAD algorithm and encryption
536/// mode, which column encryption key, and the normalization rule version.
537#[derive(Debug, Clone)]
538pub struct EncryptedParamMetadata {
539    /// Type info of the plaintext column (the `BaseTypeInfo`).
540    pub base_type_info: TypeInfo,
541    /// Encryption algorithm ID (2 = AEAD_AES_256_CBC_HMAC_SHA256).
542    pub algorithm_id: u8,
543    /// Deterministic or randomized encryption.
544    pub encryption_type: EncryptionTypeWire,
545    /// Database ID of the column encryption key.
546    pub database_id: u32,
547    /// Column encryption key ID.
548    pub cek_id: u32,
549    /// Column encryption key version.
550    pub cek_version: u32,
551    /// Column encryption key metadata version.
552    pub cek_md_version: u64,
553    /// Normalization rule version applied to the plaintext.
554    pub normalization_rule_version: u8,
555}
556
557impl EncryptedParamMetadata {
558    /// Encode the cipher-metadata trailer to the buffer, in the order the
559    /// server expects it after the ciphertext value.
560    pub fn encode(&self, buf: &mut BytesMut) {
561        self.base_type_info.encode(buf);
562        buf.put_u8(self.algorithm_id);
563        buf.put_u8(self.encryption_type.to_u8());
564        buf.put_u32_le(self.database_id);
565        buf.put_u32_le(self.cek_id);
566        buf.put_u32_le(self.cek_version);
567        buf.put_u64_le(self.cek_md_version);
568        buf.put_u8(self.normalization_rule_version);
569    }
570}
571
572/// An RPC parameter.
573#[derive(Debug, Clone)]
574pub struct RpcParam {
575    /// Parameter name (can be empty for positional params).
576    pub name: String,
577    /// Status flags.
578    pub flags: ParamFlags,
579    /// Type information.
580    pub type_info: TypeInfo,
581    /// Parameter value (raw bytes).
582    pub value: Option<Bytes>,
583    /// Always Encrypted cipher metadata, written after the value when the
584    /// parameter is encrypted. `None` for ordinary parameters.
585    pub crypto_metadata: Option<EncryptedParamMetadata>,
586}
587
588impl RpcParam {
589    /// Create a new parameter with a value.
590    pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
591        Self {
592            name: name.into(),
593            flags: ParamFlags::default(),
594            type_info,
595            value: Some(value),
596            crypto_metadata: None,
597        }
598    }
599
600    /// Create a NULL parameter.
601    pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
602        Self {
603            name: name.into(),
604            flags: ParamFlags::default(),
605            type_info,
606            value: None,
607            crypto_metadata: None,
608        }
609    }
610
611    /// Create an Always Encrypted parameter.
612    ///
613    /// `ciphertext` is the AEAD-encrypted, normalized value; `metadata` is the
614    /// cipher info the server needs to validate and route it. On the wire the
615    /// value is carried as `BIGVARBINARY(max)` with the `fEncrypted` status bit
616    /// set and the [`EncryptedParamMetadata`] trailer after the value.
617    pub fn encrypted(
618        name: impl Into<String>,
619        ciphertext: Bytes,
620        metadata: EncryptedParamMetadata,
621    ) -> Self {
622        Self {
623            name: name.into(),
624            flags: ParamFlags {
625                encrypted: true,
626                ..ParamFlags::default()
627            },
628            type_info: TypeInfo::varbinary_max(),
629            value: Some(ciphertext),
630            crypto_metadata: Some(metadata),
631        }
632    }
633
634    /// Create a NULL Always Encrypted parameter.
635    ///
636    /// The server rejects a plaintext parameter bound to an encrypted column,
637    /// even for NULL, so a NULL value is still sent encrypted: `BIGVARBINARY(max)`
638    /// with the `fEncrypted` status bit, a NULL value, and the cipher metadata.
639    pub fn encrypted_null(name: impl Into<String>, metadata: EncryptedParamMetadata) -> Self {
640        Self {
641            name: name.into(),
642            flags: ParamFlags {
643                encrypted: true,
644                ..ParamFlags::default()
645            },
646            type_info: TypeInfo::varbinary_max(),
647            value: None,
648            crypto_metadata: Some(metadata),
649        }
650    }
651
652    /// Create an INT parameter.
653    pub fn int(name: impl Into<String>, value: i32) -> Self {
654        let mut buf = BytesMut::with_capacity(4);
655        buf.put_i32_le(value);
656        Self::new(name, TypeInfo::int(), buf.freeze())
657    }
658
659    /// Create a BIGINT parameter.
660    pub fn bigint(name: impl Into<String>, value: i64) -> Self {
661        let mut buf = BytesMut::with_capacity(8);
662        buf.put_i64_le(value);
663        Self::new(name, TypeInfo::bigint(), buf.freeze())
664    }
665
666    /// Create an NVARCHAR parameter.
667    pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
668        let mut buf = BytesMut::new();
669        let mut code_units: usize = 0;
670        for code_unit in value.encode_utf16() {
671            buf.put_u16_le(code_unit);
672            code_units += 1;
673        }
674        // NVARCHAR length is measured in UTF-16 code units, not Rust chars β€”
675        // supplementary characters (emoji, CJK extension B) encode to a surrogate
676        // pair (2 code units) but count as 1 `char`. Using chars().count() here
677        // under-reports the buffer length and the server rejects the RPC with
678        // "Data type 0xE7 has an invalid data length or metadata length."
679        let type_info = if code_units > 4000 {
680            TypeInfo::nvarchar_max()
681        } else {
682            TypeInfo::nvarchar(code_units.max(1) as u16)
683        };
684        Self::new(name, type_info, buf.freeze())
685    }
686
687    /// Create a VARCHAR parameter.
688    ///
689    /// Encodes the string as single-byte characters using Windows-1252 encoding
690    /// (when the `encoding` feature is enabled) or Latin-1 fallback. Characters
691    /// not representable in the target encoding are replaced with `?`.
692    ///
693    /// Use this instead of [`nvarchar`](Self::nvarchar) when
694    /// `SendStringParametersAsUnicode=false` to allow SQL Server to use
695    /// index seeks on VARCHAR columns.
696    pub fn varchar(name: impl Into<String>, value: &str) -> Self {
697        let encoded = Self::encode_varchar_bytes(value);
698        let byte_len = encoded.len();
699        let type_info = if byte_len > 8000 {
700            TypeInfo::varchar_max()
701        } else {
702            TypeInfo::varchar(byte_len.max(1) as u16)
703        };
704        Self::new(name, type_info, Bytes::from(encoded))
705    }
706
707    /// Encode a string as single-byte VARCHAR data using the default
708    /// Windows-1252 encoding (or Latin-1 fallback without the `encoding` feature).
709    fn encode_varchar_bytes(value: &str) -> Vec<u8> {
710        crate::collation::encode_str_for_collation(value, None)
711    }
712
713    /// Create a VARCHAR parameter using the server's collation for encoding.
714    ///
715    /// Uses the collation's character encoding instead of the default Windows-1252.
716    /// For UTF-8 collations (SQL Server 2019+), the string bytes are used directly.
717    pub fn varchar_with_collation(
718        name: impl Into<String>,
719        value: &str,
720        collation: &Collation,
721    ) -> Self {
722        let collation_bytes = collation.to_bytes();
723        let encoded = Self::encode_varchar_bytes_for_collation(value, collation);
724        let byte_len = encoded.len();
725        let type_info = if byte_len > 8000 {
726            TypeInfo::varchar_max_with_collation(collation_bytes)
727        } else {
728            TypeInfo::varchar_with_collation(byte_len.max(1) as u16, collation_bytes)
729        };
730        Self::new(name, type_info, Bytes::from(encoded))
731    }
732
733    /// Encode a string using the collation's character encoding.
734    fn encode_varchar_bytes_for_collation(value: &str, collation: &Collation) -> Vec<u8> {
735        crate::collation::encode_str_for_collation(value, Some(collation))
736    }
737
738    /// Mark as output parameter.
739    #[must_use]
740    pub fn as_output(mut self) -> Self {
741        self.flags = self.flags.output();
742        self
743    }
744
745    /// Encode the parameter to buffer.
746    pub fn encode(&self, buf: &mut BytesMut) {
747        // Parameter name (B_VARCHAR - length-prefixed)
748        let name_len = self.name.encode_utf16().count() as u8;
749        buf.put_u8(name_len);
750        if name_len > 0 {
751            for code_unit in self.name.encode_utf16() {
752                buf.put_u16_le(code_unit);
753            }
754        }
755
756        // Status flags
757        buf.put_u8(self.flags.encode());
758
759        // Type info
760        self.type_info.encode(buf);
761
762        // Value
763        if let Some(ref value) = self.value {
764            // Length prefix based on type
765            match self.type_info.type_id {
766                0x26 => {
767                    // INTNTYPE
768                    buf.put_u8(value.len() as u8);
769                    buf.put_slice(value);
770                }
771                0x68 | 0x6D | 0x6E | 0x6F => {
772                    // BITNTYPE, FLTNTYPE, MONEYNTYPE, DATETIMENTYPE
773                    buf.put_u8(value.len() as u8);
774                    buf.put_slice(value);
775                }
776                0xE7 | 0xA7 | 0xA5 => {
777                    // NVARCHARTYPE, BIGVARCHARTYPE, BIGVARBINTYPE
778                    if self.type_info.max_length == Some(0xFFFF) {
779                        // MAX type - use PLP format
780                        // For simplicity, send as single chunk
781                        let total_len = value.len() as u64;
782                        buf.put_u64_le(total_len);
783                        buf.put_u32_le(value.len() as u32);
784                        buf.put_slice(value);
785                        buf.put_u32_le(0); // Terminator
786                    } else {
787                        buf.put_u16_le(value.len() as u16);
788                        buf.put_slice(value);
789                    }
790                }
791                0x24 => {
792                    // GUIDTYPE
793                    buf.put_u8(value.len() as u8);
794                    buf.put_slice(value);
795                }
796                0x28..=0x2B => {
797                    // DATE, TIME, DATETIME2, DATETIMEOFFSET
798                    buf.put_u8(value.len() as u8);
799                    buf.put_slice(value);
800                }
801                0x6C => {
802                    // DECIMALNTYPE
803                    buf.put_u8(value.len() as u8);
804                    buf.put_slice(value);
805                }
806                0xF3 => {
807                    // TVP (Table-Valued Parameter)
808                    // TVP values are self-delimiting: they contain complete metadata,
809                    // row data, and end token (TVP_END_TOKEN = 0x00). No length prefix.
810                    buf.put_slice(value);
811                }
812                _ => {
813                    // Generic: assume length-prefixed
814                    buf.put_u8(value.len() as u8);
815                    buf.put_slice(value);
816                }
817            }
818        } else {
819            // NULL value
820            match self.type_info.type_id {
821                0xE7 | 0xA7 | 0xA5 => {
822                    // Variable-length types use 0xFFFF for NULL
823                    if self.type_info.max_length == Some(0xFFFF) {
824                        buf.put_u64_le(0xFFFFFFFFFFFFFFFF); // PLP NULL
825                    } else {
826                        buf.put_u16_le(0xFFFF);
827                    }
828                }
829                _ => {
830                    buf.put_u8(0); // Zero-length for NULL
831                }
832            }
833        }
834
835        // Always Encrypted: the CryptoMetadata trailer follows the value.
836        if let Some(ref metadata) = self.crypto_metadata {
837            metadata.encode(buf);
838        }
839    }
840}
841
842/// RPC request builder.
843#[derive(Debug, Clone)]
844pub struct RpcRequest {
845    /// Procedure name (if using named procedure).
846    proc_name: Option<String>,
847    /// Procedure ID (if using well-known procedure).
848    proc_id: Option<ProcId>,
849    /// Option flags.
850    options: RpcOptionFlags,
851    /// Parameters.
852    params: Vec<RpcParam>,
853}
854
855impl RpcRequest {
856    /// Create a new RPC request for a named procedure.
857    pub fn named(proc_name: impl Into<String>) -> Self {
858        Self {
859            proc_name: Some(proc_name.into()),
860            proc_id: None,
861            options: RpcOptionFlags::default(),
862            params: Vec::new(),
863        }
864    }
865
866    /// Create a new RPC request for a well-known procedure.
867    pub fn by_id(proc_id: ProcId) -> Self {
868        Self {
869            proc_name: None,
870            proc_id: Some(proc_id),
871            options: RpcOptionFlags::default(),
872            params: Vec::new(),
873        }
874    }
875
876    /// Create an sp_executesql request.
877    ///
878    /// This is the primary method for parameterized queries.
879    ///
880    /// # Example
881    ///
882    /// ```
883    /// use tds_protocol::rpc::{RpcRequest, RpcParam};
884    ///
885    /// let rpc = RpcRequest::execute_sql(
886    ///     "SELECT * FROM users WHERE id = @p1 AND name = @p2",
887    ///     vec![
888    ///         RpcParam::int("@p1", 42),
889    ///         RpcParam::nvarchar("@p2", "Alice"),
890    ///     ],
891    /// );
892    /// ```
893    pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
894        let mut request = Self::by_id(ProcId::ExecuteSql);
895
896        // First parameter: the SQL statement (NVARCHAR(MAX))
897        request.params.push(RpcParam::nvarchar("", sql));
898
899        // Second parameter: parameter declarations
900        if !params.is_empty() {
901            let declarations = Self::build_param_declarations(&params);
902            request.params.push(RpcParam::nvarchar("", &declarations));
903        }
904
905        // Add the actual parameters
906        request.params.extend(params);
907
908        request
909    }
910
911    /// Build parameter declaration string for sp_executesql.
912    /// Build the `sp_executesql` `@params` declaration string for `params`.
913    ///
914    /// An Always Encrypted parameter declares its plaintext column type (its
915    /// [`EncryptedParamMetadata::base_type_info`]), not the `BIGVARBINARY`
916    /// transport type its value is carried as, so the declaration matches what
917    /// `sp_describe_parameter_encryption` was asked about.
918    pub fn build_param_declarations(params: &[RpcParam]) -> String {
919        params
920            .iter()
921            .map(|p| {
922                let name = if p.name.starts_with('@') {
923                    p.name.clone()
924                } else if p.name.is_empty() {
925                    // Generate positional name
926                    format!(
927                        "@p{}",
928                        params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
929                    )
930                } else {
931                    format!("@{}", p.name)
932                };
933
934                // Encrypted parameters declare their plaintext type, not the
935                // BIGVARBINARY their ciphertext rides in.
936                let ti = p
937                    .crypto_metadata
938                    .as_ref()
939                    .map(|m| &m.base_type_info)
940                    .unwrap_or(&p.type_info);
941
942                let type_name: String = match ti.type_id {
943                    0x26 => match ti.max_length {
944                        Some(1) => "tinyint".to_string(),
945                        Some(2) => "smallint".to_string(),
946                        Some(4) => "int".to_string(),
947                        Some(8) => "bigint".to_string(),
948                        _ => "int".to_string(),
949                    },
950                    0x68 => "bit".to_string(),
951                    0x6D => match ti.max_length {
952                        Some(4) => "real".to_string(),
953                        _ => "float".to_string(),
954                    },
955                    0xE7 => {
956                        if ti.max_length == Some(0xFFFF) {
957                            "nvarchar(max)".to_string()
958                        } else {
959                            let len = ti.max_length.unwrap_or(4000) / 2;
960                            format!("nvarchar({len})")
961                        }
962                    }
963                    0xA7 => {
964                        if ti.max_length == Some(0xFFFF) {
965                            "varchar(max)".to_string()
966                        } else {
967                            let len = ti.max_length.unwrap_or(8000);
968                            format!("varchar({len})")
969                        }
970                    }
971                    0xA5 => {
972                        if ti.max_length == Some(0xFFFF) {
973                            "varbinary(max)".to_string()
974                        } else {
975                            let len = ti.max_length.unwrap_or(8000);
976                            format!("varbinary({len})")
977                        }
978                    }
979                    0x24 => "uniqueidentifier".to_string(),
980                    0x28 => "date".to_string(),
981                    0x29 => {
982                        let scale = ti.scale.unwrap_or(7);
983                        format!("time({scale})")
984                    }
985                    0x2A => {
986                        let scale = ti.scale.unwrap_or(7);
987                        format!("datetime2({scale})")
988                    }
989                    0x2B => {
990                        let scale = ti.scale.unwrap_or(7);
991                        format!("datetimeoffset({scale})")
992                    }
993                    0x6C => {
994                        let precision = ti.precision.unwrap_or(18);
995                        let scale = ti.scale.unwrap_or(0);
996                        format!("decimal({precision}, {scale})")
997                    }
998                    0x6E => match ti.max_length {
999                        Some(4) => "smallmoney".to_string(),
1000                        _ => "money".to_string(),
1001                    },
1002                    0x6F => match ti.max_length {
1003                        Some(4) => "smalldatetime".to_string(),
1004                        _ => "datetime".to_string(),
1005                    },
1006                    0xF3 => {
1007                        // TVP - Table-Valued Parameter
1008                        // Must be declared with the table type name and READONLY
1009                        if let Some(ref tvp_name) = ti.tvp_type_name {
1010                            format!("{tvp_name} READONLY")
1011                        } else {
1012                            // Fallback if type name is missing (shouldn't happen)
1013                            "sql_variant".to_string()
1014                        }
1015                    }
1016                    _ => "sql_variant".to_string(),
1017                };
1018
1019                format!("{name} {type_name}")
1020            })
1021            .collect::<Vec<_>>()
1022            .join(", ")
1023    }
1024
1025    /// Create an sp_prepare request.
1026    pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
1027        let mut request = Self::by_id(ProcId::Prepare);
1028
1029        // OUT: handle (INT)
1030        request
1031            .params
1032            .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
1033
1034        // Param declarations
1035        let declarations = Self::build_param_declarations(params);
1036        request
1037            .params
1038            .push(RpcParam::nvarchar("@params", &declarations));
1039
1040        // SQL statement
1041        request.params.push(RpcParam::nvarchar("@stmt", sql));
1042
1043        // Options (1 = WITH RECOMPILE)
1044        request.params.push(RpcParam::int("@options", 1));
1045
1046        request
1047    }
1048
1049    /// Create an sp_execute request.
1050    pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
1051        let mut request = Self::by_id(ProcId::Execute);
1052
1053        // Handle from sp_prepare
1054        request.params.push(RpcParam::int("@handle", handle));
1055
1056        // Add parameters
1057        request.params.extend(params);
1058
1059        request
1060    }
1061
1062    /// Create an sp_unprepare request.
1063    pub fn unprepare(handle: i32) -> Self {
1064        let mut request = Self::by_id(ProcId::Unprepare);
1065        request.params.push(RpcParam::int("@handle", handle));
1066        request
1067    }
1068
1069    /// Set option flags.
1070    #[must_use]
1071    pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
1072        self.options = options;
1073        self
1074    }
1075
1076    /// Add a parameter.
1077    #[must_use]
1078    pub fn param(mut self, param: RpcParam) -> Self {
1079        self.params.push(param);
1080        self
1081    }
1082
1083    /// Encode the RPC request to bytes (auto-commit mode).
1084    ///
1085    /// For requests within an explicit transaction, use [`Self::encode_with_transaction`].
1086    #[must_use]
1087    pub fn encode(&self) -> Bytes {
1088        self.encode_with_transaction(0)
1089    }
1090
1091    /// Encode the RPC request with a transaction descriptor.
1092    ///
1093    /// Per MS-TDS spec, when executing within an explicit transaction:
1094    /// - The `transaction_descriptor` MUST be the value returned by the server
1095    ///   in the BeginTransaction EnvChange token.
1096    /// - For auto-commit mode (no explicit transaction), use 0.
1097    ///
1098    /// # Arguments
1099    ///
1100    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
1101    ///   or 0 for auto-commit mode.
1102    #[must_use]
1103    pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
1104        let mut buf = BytesMut::with_capacity(256);
1105
1106        // ALL_HEADERS - TDS 7.2+ requires this section
1107        // Total length placeholder (will be filled in)
1108        let all_headers_start = buf.len();
1109        buf.put_u32_le(0); // Total length placeholder
1110
1111        // Transaction descriptor header (required for RPC)
1112        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
1113        buf.put_u32_le(18); // Header length
1114        buf.put_u16_le(0x0002); // Header type: transaction descriptor
1115        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
1116        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
1117
1118        // Fill in ALL_HEADERS total length
1119        let all_headers_len = buf.len() - all_headers_start;
1120        let len_bytes = (all_headers_len as u32).to_le_bytes();
1121        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
1122
1123        // Procedure name or ID
1124        if let Some(proc_id) = self.proc_id {
1125            // Use PROCID format
1126            buf.put_u16_le(0xFFFF); // Name length = 0xFFFF indicates PROCID follows
1127            buf.put_u16_le(proc_id as u16);
1128        } else if let Some(ref proc_name) = self.proc_name {
1129            // Use procedure name
1130            let name_len = proc_name.encode_utf16().count() as u16;
1131            buf.put_u16_le(name_len);
1132            write_utf16_string(&mut buf, proc_name);
1133        }
1134
1135        // Option flags
1136        buf.put_u16_le(self.options.encode());
1137
1138        // Parameters
1139        for param in &self.params {
1140            param.encode(&mut buf);
1141        }
1142
1143        buf.freeze()
1144    }
1145}
1146
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// The encrypted-RPC-param wire format must match `Microsoft.Data.SqlClient`
    /// byte-for-byte. Goldens were captured from a live deterministic Always
    /// Encrypted INSERT (`.tmp/ae-wire-encrypted-param.md`): one INT, one
    /// NVARCHAR(50), and one VARBINARY(50) parameter, all bound to a single CEK
    /// (db_id=6, cek_id=2, cek_version=1, md_version=0x0000b469002f7223). The
    /// ciphertext is sliced out of each golden and fed back through the encoder,
    /// so a match proves the full framing (fEncrypted status, BIGVARBINARY(max)
    /// PLP value, and the CryptoMetadata trailer) reproduces the real client.
    #[test]
    fn encrypted_param_encode_matches_captured_dotnet_wire() {
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        // @s carries the encrypted column's BIN2 collation in its BaseTypeInfo.
        let mut nvarchar_50 = TypeInfo::nvarchar(50);
        nvarchar_50.collation = Some([0x09, 0x04, 0xd0, 0x00, 0x34]);

        let cases = [
            (
                "@i",
                TypeInfo::int(),
                "024000690008a5ffff41000000000000004100000001ed7b8c6030870d92358f12acb0d0c69c00bc3aa3ba578ecb0ea5f514b5045912a2b1ae52ed834f6bac49520956e4a574c30d573590fb3785556c8fe42f87c5b4000000002604020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@s",
                nvarchar_50,
                "024000730008a5ffff4100000000000000410000000150c0a7dec4d4241c7a4a617007d32d97e7131f8c57a5ad212487891170f12ecb9957fce16389f4728d1c3c65813beeea085ae3fd516d29f84298df3e97f0d05d00000000e764000904d00034020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@b",
                TypeInfo::varbinary(50),
                "024000620008a5ffff41000000000000004100000001d17165aa6df0155be6b78c6712d3b03870ea394cfed10956cf07fbfa204c4b82cddfa5e2f4fc03335f579e2767657e3067cd9da7d62a07427106b91f747b97da00000000a53200020106000000020000000100000023722f0069b4000001",
            ),
        ];

        for (name, base_type_info, golden_hex) in cases {
            let golden = unhex(golden_hex);
            // Slice the ciphertext out of the golden: B_VARCHAR name
            // (1 + 2*chars) + status(1) + A5FFFF(3) + PLP total(8) + chunk(4).
            let cipher_off = 1 + name.encode_utf16().count() * 2 + 1 + 3 + 8 + 4;
            // Every golden carries its ciphertext as a single 0x41 (65) byte PLP chunk.
            let cipher = Bytes::copy_from_slice(&golden[cipher_off..cipher_off + 65]);

            let param = RpcParam::encrypted(
                name,
                cipher,
                EncryptedParamMetadata {
                    base_type_info,
                    algorithm_id: 2,
                    encryption_type: EncryptionTypeWire::Deterministic,
                    database_id: 6,
                    cek_id: 2,
                    cek_version: 1,
                    cek_md_version: 0x0000_b469_002f_7223,
                    normalization_rule_version: 1,
                },
            );

            let mut buf = BytesMut::new();
            param.encode(&mut buf);
            assert_eq!(
                buf.to_vec(),
                golden,
                "encrypted {name} param must match the captured Microsoft.Data.SqlClient bytes"
            );
        }
    }

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // UTF-16 encoded "Alice" = 10 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_nvarchar_param_surrogate_pair_length() {
        // 🌍 is a supplementary character — 1 Rust char but 2 UTF-16 code units
        // (4 bytes). TypeInfo.max_length is stored doubled internally, so
        // the metadata must declare 2 code units for the buffer to match.
        let param = RpcParam::nvarchar("@p", "🌍");
        assert_eq!(param.value.as_ref().unwrap().len(), 4);
        // TypeInfo::nvarchar(n) stores max_length as n*2 bytes.
        assert_eq!(param.type_info.max_length, Some(4));

        let param = RpcParam::nvarchar("@p", "Hello 世界 🌍");
        // "Hello 世界 " = 9 BMP code units + 🌍 = 2 surrogate units → 11 code units, 22 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 22);
        assert_eq!(param.type_info.max_length, Some(22));
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // SQL statement + param declarations + actual params
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // handle (output), params, stmt, options
        assert_eq!(rpc.params.len(), 4);
        assert!(rpc.params[0].flags.by_ref); // handle is OUTPUT
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        assert_eq!(rpc.params.len(), 2); // handle + param
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        assert_eq!(rpc.params.len(), 1); // just the handle
    }

    #[test]
    fn test_varchar_param() {
        let param = RpcParam::varchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xA7);
        // Single-byte encoded "Alice" = 5 bytes
        assert_eq!(param.value.as_ref().unwrap().len(), 5);
        assert_eq!(&param.value.as_ref().unwrap()[..], b"Alice");
    }

    #[test]
    fn test_varchar_param_max() {
        // String > 8000 bytes should use VARCHAR(MAX)
        let long_str = "a".repeat(9000);
        let param = RpcParam::varchar("@big", &long_str);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.max_length, Some(0xFFFF));
        assert_eq!(param.value.as_ref().unwrap().len(), 9000);
    }

    #[test]
    fn test_varchar_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::varchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name varchar(5)"));
    }

    #[test]
    fn test_varchar_type_info_has_collation() {
        let ti = TypeInfo::varchar(100);
        assert_eq!(ti.type_id, 0xA7);
        assert_eq!(ti.max_length, Some(100));
        assert!(ti.collation.is_some());
    }

    #[test]
    fn test_varchar_encode_round_trip() {
        // Verify the encoded param can be serialized without panics
        let param = RpcParam::varchar("@val", "test value");
        let mut buf = bytes::BytesMut::new();
        param.encode(&mut buf);
        assert!(!buf.is_empty());
    }

    #[test]
    fn test_collation_round_trip() {
        let collation = Collation {
            lcid: 0x00D0_0409,
            sort_id: 0x34,
        };
        let bytes = collation.to_bytes();
        assert_eq!(bytes, [0x09, 0x04, 0xD0, 0x00, 0x34]);

        let restored = Collation::from_bytes(&bytes);
        assert_eq!(restored.lcid, collation.lcid);
        assert_eq!(restored.sort_id, collation.sort_id);
    }

    #[test]
    fn test_varchar_with_collation_uses_custom_collation_bytes() {
        // Chinese_PRC_CI_AS collation (LCID 0x0804)
        let collation = Collation {
            lcid: 0x0804,
            sort_id: 0,
        };
        let param = RpcParam::varchar_with_collation("@val", "test", &collation);
        assert_eq!(param.type_info.type_id, 0xA7);
        // Collation bytes should match the custom collation, not default Latin1
        assert_eq!(param.type_info.collation, Some(collation.to_bytes()));
    }

    #[test]
    fn test_money_type_info() {
        let ti = TypeInfo::money();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(8));
    }

    #[test]
    fn test_smallmoney_type_info() {
        let ti = TypeInfo::smallmoney();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_smalldatetime_type_info() {
        let ti = TypeInfo::smalldatetime();
        assert_eq!(ti.type_id, 0x6F);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_money_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::new("@m", TypeInfo::money(), Bytes::from_static(&[0u8; 8])),
            RpcParam::new("@sm", TypeInfo::smallmoney(), Bytes::from_static(&[0u8; 4])),
            RpcParam::new(
                "@sdt",
                TypeInfo::smalldatetime(),
                Bytes::from_static(&[0u8; 4]),
            ),
        ]);
        assert!(decls.contains("@m money"), "got: {decls}");
        assert!(decls.contains("@sm smallmoney"), "got: {decls}");
        assert!(decls.contains("@sdt smalldatetime"), "got: {decls}");
    }

    #[test]
    fn test_money_typeinfo_encodes_max_length_byte() {
        let mut buf = bytes::BytesMut::new();
        TypeInfo::money().encode(&mut buf);
        // type_id 0x6E + max_length byte 0x08
        assert_eq!(&buf[..], &[0x6E, 0x08]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smallmoney().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x04]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smalldatetime().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6F, 0x04]);
    }

    #[test]
    fn test_varchar_with_collation_default_vs_custom_differ() {
        let default_param = RpcParam::varchar("@val", "test");
        let custom_collation = Collation {
            lcid: 0x0419, // Russian
            sort_id: 0,
        };
        let custom_param = RpcParam::varchar_with_collation("@val", "test", &custom_collation);
        // The collation bytes should differ
        assert_ne!(
            default_param.type_info.collation,
            custom_param.type_info.collation
        );
    }
}