Skip to main content

sqlx_sqlserver/
arguments.rs

1use std::fmt::{self, Write};
2
3use sqlx_core::arguments::Arguments;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7
8use crate::{Mssql, MssqlType, MssqlTypeInfo};
9
// TDS TYPE_INFO tokens written in front of each RPC parameter value.
const DATA_TYPE_INTN: u8 = 0x26;
const DATA_TYPE_BITN: u8 = 0x68;
const DATA_TYPE_FLOATN: u8 = 0x6D;
const DATA_TYPE_GUID: u8 = 0x24;
const DATA_TYPE_NUMERICN: u8 = 0x6C;
const DATA_TYPE_DATEN: u8 = 0x28;
const DATA_TYPE_TIMEN: u8 = 0x29;
const DATA_TYPE_DATETIME2N: u8 = 0x2A;
const DATA_TYPE_DATETIMEOFFSETN: u8 = 0x2B;
const DATA_TYPE_BIGVARBINARY: u8 = 0xA5;
const DATA_TYPE_BIGVARCHAR: u8 = 0xA7;
const DATA_TYPE_NVARCHAR: u8 = 0xE7;
/// 5-byte collation written after the length of every `nvarchar`/`varchar` TYPE_INFO.
const DEFAULT_COLLATION: [u8; 5] = [0x81, 0x04, 0xD0, 0x00, 0x34];
/// PLP total-length value that encodes a NULL `(max)` parameter.
const PLP_NULL: u64 = u64::MAX;
/// Largest payload written per PLP chunk when streaming a `(max)` value.
const PLP_CHUNK_SIZE: usize = 8 * 1024;
/// RPC status flag used for output (by-reference) parameters.
const STATUS_BY_REF_VALUE: u8 = 0x01;
26
/// SQL Server argument buffer: encoded RPC parameters plus their matching declarations.
#[derive(Clone, Debug, Default)]
pub struct MssqlArguments {
    /// Number of parameters added so far; the most recent one is named `@p{len}`.
    len: usize,
    /// Concatenated RPC parameter bytes (name, status, TYPE_INFO, value) for every argument.
    data: Vec<u8>,
    /// Comma-separated `@pN <type>` declarations, e.g. `@p1 int,@p2 nvarchar(2)`.
    declarations: String,
}
34
35impl MssqlArguments {
36    /// Returns `true` when no arguments were added.
37    pub const fn is_empty(&self) -> bool {
38        self.len == 0
39    }
40
41    pub(crate) fn data(&self) -> &[u8] {
42        &self.data
43    }
44
45    pub(crate) fn declarations(&self) -> &str {
46        &self.declarations
47    }
48
49    fn add_parameter(
50        &mut self,
51        type_info: MssqlTypeInfo,
52        encoded: Vec<u8>,
53        is_null: bool,
54    ) -> Result<(), BoxDynError> {
55        self.len += 1;
56        let name = format!("@p{}", self.len);
57
58        if !self.declarations.is_empty() {
59            self.declarations.push(',');
60        }
61
62        write!(
63            self.declarations,
64            "{name} {}",
65            declaration(&type_info, encoded.len(), is_null)?
66        )?;
67
68        write_parameter(&mut self.data, &name, &type_info, &encoded, is_null)?;
69        Ok(())
70    }
71}
72
73impl Arguments for MssqlArguments {
74    type Database = Mssql;
75
76    fn reserve(&mut self, _additional: usize, _size: usize) {}
77
78    fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
79    where
80        T: Encode<'t, Self::Database> + Type<Self::Database>,
81    {
82        let type_info = value.produces().unwrap_or_else(T::type_info);
83        let mut encoded = Vec::with_capacity(value.size_hint());
84        let is_null = matches!(value.encode(&mut encoded)?, IsNull::Yes);
85        self.add_parameter(type_info, encoded, is_null)?;
86        Ok(())
87    }
88
89    fn len(&self) -> usize {
90        self.len
91    }
92
93    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
94        write!(writer, "@p{}", self.len)
95    }
96}
97
98pub(crate) fn write_parameter(
99    out: &mut Vec<u8>,
100    name: &str,
101    type_info: &MssqlTypeInfo,
102    encoded: &[u8],
103    is_null: bool,
104) -> Result<(), BoxDynError> {
105    write_parameter_with_status(out, name, 0, type_info, encoded, is_null)
106}
107
108pub(crate) fn write_output_i32_parameter(
109    out: &mut Vec<u8>,
110    name: &str,
111    value: i32,
112) -> Result<(), BoxDynError> {
113    write_parameter_with_status(
114        out,
115        name,
116        STATUS_BY_REF_VALUE,
117        &MssqlTypeInfo::INT,
118        &value.to_le_bytes(),
119        false,
120    )
121}
122
123fn write_parameter_with_status(
124    out: &mut Vec<u8>,
125    name: &str,
126    status: u8,
127    type_info: &MssqlTypeInfo,
128    encoded: &[u8],
129    is_null: bool,
130) -> Result<(), BoxDynError> {
131    write_b_varchar(out, name)?;
132    out.push(status);
133    write_type_info(out, type_info, encoded.len(), is_null)?;
134    write_param_len_data(out, type_info, encoded, is_null)?;
135    Ok(())
136}
137
138pub(crate) fn write_nvarchar_parameter(
139    out: &mut Vec<u8>,
140    name: &str,
141    value: &str,
142) -> Result<(), BoxDynError> {
143    let mut encoded = Vec::with_capacity(value.len() * 2);
144    write_utf16(&mut encoded, value);
145    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &encoded, false)
146}
147
148pub(crate) fn write_null_nvarchar_parameter(
149    out: &mut Vec<u8>,
150    name: &str,
151) -> Result<(), BoxDynError> {
152    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &[], true)
153}
154
155pub(crate) fn type_declaration(type_info: &MssqlTypeInfo) -> Result<&'static str, BoxDynError> {
156    Ok(match type_info.kind() {
157        MssqlType::Bit => "bit",
158        MssqlType::TinyInt => "tinyint",
159        MssqlType::SmallInt => "smallint",
160        MssqlType::Int => "int",
161        MssqlType::BigInt => "bigint",
162        MssqlType::Real => "real",
163        MssqlType::Float => "float",
164        MssqlType::NVarChar => "nvarchar(max)",
165        MssqlType::VarChar => "varchar(max)",
166        MssqlType::VarBinary => "varbinary(max)",
167        MssqlType::Decimal => "decimal(38, 0)",
168        MssqlType::Date => "date",
169        MssqlType::Time => "time(7)",
170        MssqlType::DateTime2 => "datetime2(7)",
171        MssqlType::DateTimeOffset => "datetimeoffset(7)",
172        MssqlType::UniqueIdentifier => "uniqueidentifier",
173        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
174    })
175}
176
177fn write_type_info(
178    out: &mut Vec<u8>,
179    type_info: &MssqlTypeInfo,
180    encoded_len: usize,
181    is_null: bool,
182) -> Result<(), BoxDynError> {
183    match type_info.kind() {
184        MssqlType::Bit => {
185            out.push(DATA_TYPE_BITN);
186            out.push(1);
187        }
188        MssqlType::TinyInt => {
189            out.push(DATA_TYPE_INTN);
190            out.push(1);
191        }
192        MssqlType::SmallInt => {
193            out.push(DATA_TYPE_INTN);
194            out.push(2);
195        }
196        MssqlType::Int => {
197            out.push(DATA_TYPE_INTN);
198            out.push(4);
199        }
200        MssqlType::BigInt => {
201            out.push(DATA_TYPE_INTN);
202            out.push(8);
203        }
204        MssqlType::Real => {
205            out.push(DATA_TYPE_FLOATN);
206            out.push(4);
207        }
208        MssqlType::Float => {
209            out.push(DATA_TYPE_FLOATN);
210            out.push(8);
211        }
212        MssqlType::NVarChar => {
213            out.push(DATA_TYPE_NVARCHAR);
214            out.extend_from_slice(
215                &nvarchar_type_size(type_info, encoded_len, is_null)?.to_le_bytes(),
216            );
217            out.extend_from_slice(&DEFAULT_COLLATION);
218        }
219        MssqlType::VarChar => {
220            out.push(DATA_TYPE_BIGVARCHAR);
221            out.extend_from_slice(
222                &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
223            );
224            out.extend_from_slice(&DEFAULT_COLLATION);
225        }
226        MssqlType::VarBinary => {
227            out.push(DATA_TYPE_BIGVARBINARY);
228            out.extend_from_slice(
229                &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
230            );
231        }
232        MssqlType::Decimal => {
233            out.push(DATA_TYPE_NUMERICN);
234            out.push(u8::try_from(type_info.size().unwrap_or(17))?);
235            out.push(type_info.precision().max(1));
236            out.push(type_info.scale());
237        }
238        MssqlType::Date => {
239            out.push(DATA_TYPE_DATEN);
240        }
241        MssqlType::Time => {
242            out.push(DATA_TYPE_TIMEN);
243            out.push(type_info.scale());
244        }
245        MssqlType::DateTime2 => {
246            out.push(DATA_TYPE_DATETIME2N);
247            out.push(type_info.scale());
248        }
249        MssqlType::DateTimeOffset => {
250            out.push(DATA_TYPE_DATETIMEOFFSETN);
251            out.push(type_info.scale());
252        }
253        MssqlType::UniqueIdentifier => {
254            out.push(DATA_TYPE_GUID);
255            out.push(16);
256        }
257        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
258    }
259
260    Ok(())
261}
262
263fn write_param_len_data(
264    out: &mut Vec<u8>,
265    type_info: &MssqlTypeInfo,
266    encoded: &[u8],
267    is_null: bool,
268) -> Result<(), BoxDynError> {
269    match type_info.kind() {
270        MssqlType::Bit
271        | MssqlType::TinyInt
272        | MssqlType::SmallInt
273        | MssqlType::Int
274        | MssqlType::BigInt
275        | MssqlType::Real
276        | MssqlType::Float
277        | MssqlType::Decimal
278        | MssqlType::Date
279        | MssqlType::Time
280        | MssqlType::DateTime2
281        | MssqlType::DateTimeOffset
282        | MssqlType::UniqueIdentifier => {
283            out.push(if is_null {
284                0
285            } else {
286                u8::try_from(encoded.len())?
287            });
288        }
289        MssqlType::NVarChar | MssqlType::VarChar | MssqlType::VarBinary => {
290            if type_info.size() == Some(u16::MAX) {
291                write_plp_value(out, encoded, is_null)?;
292            } else {
293                let len = if is_null {
294                    u16::MAX
295                } else {
296                    u16::try_from(encoded.len())?
297                };
298                out.extend_from_slice(&len.to_le_bytes());
299            }
300        }
301        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
302    }
303
304    if !is_null && type_info.size() != Some(u16::MAX) {
305        out.extend_from_slice(encoded);
306    }
307
308    Ok(())
309}
310
311fn declaration(
312    type_info: &MssqlTypeInfo,
313    encoded_len: usize,
314    is_null: bool,
315) -> Result<String, BoxDynError> {
316    Ok(match type_info.kind() {
317        MssqlType::Bit => "bit".to_owned(),
318        MssqlType::TinyInt => "tinyint".to_owned(),
319        MssqlType::SmallInt => "smallint".to_owned(),
320        MssqlType::Int => "int".to_owned(),
321        MssqlType::BigInt => "bigint".to_owned(),
322        MssqlType::Real => "real".to_owned(),
323        MssqlType::Float => "float".to_owned(),
324        MssqlType::NVarChar => nvarchar_declaration(type_info, encoded_len, is_null)?,
325        MssqlType::VarChar => varchar_declaration(type_info, encoded_len, is_null)?,
326        MssqlType::VarBinary => varbinary_declaration(type_info, encoded_len, is_null)?,
327        MssqlType::Decimal => format!(
328            "decimal({},{})",
329            type_info.precision().max(1),
330            type_info.scale()
331        ),
332        MssqlType::Date => "date".to_owned(),
333        MssqlType::Time => format!("time({})", type_info.scale()),
334        MssqlType::DateTime2 => format!("datetime2({})", type_info.scale()),
335        MssqlType::DateTimeOffset => format!("datetimeoffset({})", type_info.scale()),
336        MssqlType::UniqueIdentifier => "uniqueidentifier".to_owned(),
337        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
338    })
339}
340
341fn nvarchar_declaration(
342    type_info: &MssqlTypeInfo,
343    encoded_len: usize,
344    is_null: bool,
345) -> Result<String, BoxDynError> {
346    let size = nvarchar_type_size(type_info, encoded_len, is_null)?;
347    if size == u16::MAX {
348        Ok("nvarchar(max)".to_owned())
349    } else {
350        Ok(format!("nvarchar({})", size / 2))
351    }
352}
353
354fn varchar_declaration(
355    type_info: &MssqlTypeInfo,
356    encoded_len: usize,
357    is_null: bool,
358) -> Result<String, BoxDynError> {
359    let size = bounded_short_len(type_info, encoded_len, is_null)?;
360    if size == u16::MAX {
361        Ok("varchar(max)".to_owned())
362    } else {
363        Ok(format!("varchar({size})"))
364    }
365}
366
367fn varbinary_declaration(
368    type_info: &MssqlTypeInfo,
369    encoded_len: usize,
370    is_null: bool,
371) -> Result<String, BoxDynError> {
372    let size = bounded_short_len(type_info, encoded_len, is_null)?;
373    if size == u16::MAX {
374        Ok("varbinary(max)".to_owned())
375    } else {
376        Ok(format!("varbinary({size})"))
377    }
378}
379
380fn nvarchar_type_size(
381    type_info: &MssqlTypeInfo,
382    encoded_len: usize,
383    is_null: bool,
384) -> Result<u16, BoxDynError> {
385    if let Some(size) = type_info.size() {
386        return Ok(size);
387    }
388
389    let len = if is_null {
390        2
391    } else {
392        std::cmp::max(2, encoded_len)
393    };
394    Ok(u16::try_from(len)?)
395}
396
397fn bounded_short_len(
398    type_info: &MssqlTypeInfo,
399    encoded_len: usize,
400    is_null: bool,
401) -> Result<u16, BoxDynError> {
402    if let Some(size) = type_info.size() {
403        return Ok(size);
404    }
405
406    let len = if is_null {
407        1
408    } else {
409        std::cmp::max(1, encoded_len)
410    };
411    Ok(u16::try_from(len)?)
412}
413
414fn write_plp_value(out: &mut Vec<u8>, encoded: &[u8], is_null: bool) -> Result<(), BoxDynError> {
415    if is_null {
416        out.extend_from_slice(&PLP_NULL.to_le_bytes());
417        return Ok(());
418    }
419
420    out.extend_from_slice(&u64::try_from(encoded.len())?.to_le_bytes());
421
422    for chunk in encoded.chunks(PLP_CHUNK_SIZE) {
423        out.extend_from_slice(&u32::try_from(chunk.len())?.to_le_bytes());
424        out.extend_from_slice(chunk);
425    }
426
427    out.extend_from_slice(&0_u32.to_le_bytes());
428    Ok(())
429}
430
431fn write_b_varchar(out: &mut Vec<u8>, value: &str) -> Result<(), BoxDynError> {
432    let char_len = value.encode_utf16().count();
433    out.push(u8::try_from(char_len)?);
434    write_utf16(out, value);
435    Ok(())
436}
437
/// Appends `value` to `out` as UTF-16 code units in little-endian byte order, with no
/// length prefix.
fn write_utf16(out: &mut Vec<u8>, value: &str) {
    out.extend(value.encode_utf16().flat_map(|unit| unit.to_le_bytes()));
}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `needle` occurs as a contiguous run somewhere in `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn formats_sql_server_style_placeholders() {
        let args = MssqlArguments {
            len: 3,
            data: Vec::new(),
            declarations: String::new(),
        };
        let mut placeholder = String::new();

        args.format_placeholder(&mut placeholder).unwrap();

        assert_eq!("@p3", placeholder);
    }

    #[test]
    fn records_declarations_and_rpc_argument_data() {
        let mut args = MssqlArguments::default();
        args.add(7_i32).unwrap();
        args.add("hi").unwrap();

        assert_eq!("@p1 int,@p2 nvarchar(2)", args.declarations());

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 4]));
        assert!(contains_bytes(
            data,
            &[DATA_TYPE_NVARCHAR, 4, 0, 0x81, 0x04, 0xd0, 0x00, 0x34]
        ));
    }

    #[test]
    fn declares_lossless_integer_parameter_types() {
        let mut args = MssqlArguments::default();
        args.add(-5_i8).unwrap();
        args.add(255_u8).unwrap();
        args.add(65_535_u16).unwrap();
        args.add(u32::MAX).unwrap();

        assert_eq!(
            "@p1 smallint,@p2 tinyint,@p3 int,@p4 bigint",
            args.declarations()
        );

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 1]));
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 8]));
    }

    #[test]
    fn declares_large_text_and_binary_parameters_as_max() {
        let mut args = MssqlArguments::default();
        let text = "x".repeat(4001);
        let bytes = vec![0x5a; 8001];
        args.add(text.as_str()).unwrap();
        args.add(bytes.as_slice()).unwrap();

        assert_eq!("@p1 nvarchar(max),@p2 varbinary(max)", args.declarations());

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_NVARCHAR, 0xff, 0xff]));
        assert!(contains_bytes(data, &[DATA_TYPE_BIGVARBINARY, 0xff, 0xff]));
        assert!(contains_bytes(data, &8002_u64.to_le_bytes()));
        assert!(contains_bytes(data, &8001_u64.to_le_bytes()));
    }
}