Skip to main content

sqlx_sqlserver/
arguments.rs

1use std::fmt::{self, Write};
2
3use sqlx_core::arguments::Arguments;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7
8use crate::{Mssql, MssqlType, MssqlTypeInfo};
9
/// TDS type token for the nullable integer type (`INTNTYPE`).
const DATA_TYPE_INTN: u8 = 0x26;
/// TDS type token for the nullable bit type (`BITNTYPE`).
const DATA_TYPE_BITN: u8 = 0x68;
/// TDS type token for the nullable floating-point type (`FLTNTYPE`).
const DATA_TYPE_FLOATN: u8 = 0x6d;
/// TDS type token for `varbinary` (`BIGVARBINARYTYPE`).
const DATA_TYPE_BIGVARBINARY: u8 = 0xa5;
/// TDS type token for `varchar` (`BIGVARCHARTYPE`).
const DATA_TYPE_BIGVARCHAR: u8 = 0xa7;
/// TDS type token for `nvarchar` (`NVARCHARTYPE`).
const DATA_TYPE_NVARCHAR: u8 = 0xe7;
/// Five-byte TDS collation written after every `varchar`/`nvarchar` type header.
// NOTE(review): which LCID / sort order these bytes select is not stated here — confirm.
const DEFAULT_COLLATION: [u8; 5] = [0x81, 0x04, 0xd0, 0x00, 0x34];
/// Total-length value that marks a NULL PLP (partially length-prefixed) value.
const PLP_NULL: u64 = u64::MAX;
/// Largest number of payload bytes emitted in a single PLP chunk.
const PLP_CHUNK_SIZE: usize = 8 * 1024;
/// RPC parameter status flag used for output (by-reference) parameters.
const STATUS_BY_REF_VALUE: u8 = 0x01;
20
21/// SQL Server argument buffer.
22#[derive(Debug, Default, Clone)]
23pub struct MssqlArguments {
24    len: usize,
25    data: Vec<u8>,
26    declarations: String,
27}
28
29impl MssqlArguments {
30    /// Returns `true` when no arguments were added.
31    pub const fn is_empty(&self) -> bool {
32        self.len == 0
33    }
34
35    pub(crate) fn data(&self) -> &[u8] {
36        &self.data
37    }
38
39    pub(crate) fn declarations(&self) -> &str {
40        &self.declarations
41    }
42
43    fn add_parameter(
44        &mut self,
45        type_info: MssqlTypeInfo,
46        encoded: Vec<u8>,
47        is_null: bool,
48    ) -> Result<(), BoxDynError> {
49        self.len += 1;
50        let name = format!("@p{}", self.len);
51
52        if !self.declarations.is_empty() {
53            self.declarations.push(',');
54        }
55
56        write!(
57            self.declarations,
58            "{name} {}",
59            declaration(&type_info, encoded.len(), is_null)?
60        )?;
61
62        write_parameter(&mut self.data, &name, &type_info, &encoded, is_null)?;
63        Ok(())
64    }
65}
66
67impl Arguments for MssqlArguments {
68    type Database = Mssql;
69
70    fn reserve(&mut self, _additional: usize, _size: usize) {}
71
72    fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
73    where
74        T: Encode<'t, Self::Database> + Type<Self::Database>,
75    {
76        let type_info = value.produces().unwrap_or_else(T::type_info);
77        let mut encoded = Vec::with_capacity(value.size_hint());
78        let is_null = matches!(value.encode(&mut encoded)?, IsNull::Yes);
79        self.add_parameter(type_info, encoded, is_null)?;
80        Ok(())
81    }
82
83    fn len(&self) -> usize {
84        self.len
85    }
86
87    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
88        write!(writer, "@p{}", self.len)
89    }
90}
91
92pub(crate) fn write_parameter(
93    out: &mut Vec<u8>,
94    name: &str,
95    type_info: &MssqlTypeInfo,
96    encoded: &[u8],
97    is_null: bool,
98) -> Result<(), BoxDynError> {
99    write_parameter_with_status(out, name, 0, type_info, encoded, is_null)
100}
101
102pub(crate) fn write_output_i32_parameter(
103    out: &mut Vec<u8>,
104    name: &str,
105    value: i32,
106) -> Result<(), BoxDynError> {
107    write_parameter_with_status(
108        out,
109        name,
110        STATUS_BY_REF_VALUE,
111        &MssqlTypeInfo::INT,
112        &value.to_le_bytes(),
113        false,
114    )
115}
116
117fn write_parameter_with_status(
118    out: &mut Vec<u8>,
119    name: &str,
120    status: u8,
121    type_info: &MssqlTypeInfo,
122    encoded: &[u8],
123    is_null: bool,
124) -> Result<(), BoxDynError> {
125    write_b_varchar(out, name)?;
126    out.push(status);
127    write_type_info(out, type_info, encoded.len(), is_null)?;
128    write_param_len_data(out, type_info, encoded, is_null)?;
129    Ok(())
130}
131
132pub(crate) fn write_nvarchar_parameter(
133    out: &mut Vec<u8>,
134    name: &str,
135    value: &str,
136) -> Result<(), BoxDynError> {
137    let mut encoded = Vec::with_capacity(value.len() * 2);
138    write_utf16(&mut encoded, value);
139    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &encoded, false)
140}
141
142pub(crate) fn write_null_nvarchar_parameter(
143    out: &mut Vec<u8>,
144    name: &str,
145) -> Result<(), BoxDynError> {
146    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &[], true)
147}
148
149pub(crate) fn type_declaration(type_info: &MssqlTypeInfo) -> Result<&'static str, BoxDynError> {
150    Ok(match type_info.kind() {
151        MssqlType::Bit => "bit",
152        MssqlType::TinyInt => "tinyint",
153        MssqlType::SmallInt => "smallint",
154        MssqlType::Int => "int",
155        MssqlType::BigInt => "bigint",
156        MssqlType::Real => "real",
157        MssqlType::Float => "float",
158        MssqlType::NVarChar => "nvarchar(max)",
159        MssqlType::VarChar => "varchar(max)",
160        MssqlType::VarBinary => "varbinary(max)",
161        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
162    })
163}
164
165fn write_type_info(
166    out: &mut Vec<u8>,
167    type_info: &MssqlTypeInfo,
168    encoded_len: usize,
169    is_null: bool,
170) -> Result<(), BoxDynError> {
171    match type_info.kind() {
172        MssqlType::Bit => {
173            out.push(DATA_TYPE_BITN);
174            out.push(1);
175        }
176        MssqlType::TinyInt => {
177            out.push(DATA_TYPE_INTN);
178            out.push(1);
179        }
180        MssqlType::SmallInt => {
181            out.push(DATA_TYPE_INTN);
182            out.push(2);
183        }
184        MssqlType::Int => {
185            out.push(DATA_TYPE_INTN);
186            out.push(4);
187        }
188        MssqlType::BigInt => {
189            out.push(DATA_TYPE_INTN);
190            out.push(8);
191        }
192        MssqlType::Real => {
193            out.push(DATA_TYPE_FLOATN);
194            out.push(4);
195        }
196        MssqlType::Float => {
197            out.push(DATA_TYPE_FLOATN);
198            out.push(8);
199        }
200        MssqlType::NVarChar => {
201            out.push(DATA_TYPE_NVARCHAR);
202            out.extend_from_slice(
203                &nvarchar_type_size(type_info, encoded_len, is_null)?.to_le_bytes(),
204            );
205            out.extend_from_slice(&DEFAULT_COLLATION);
206        }
207        MssqlType::VarChar => {
208            out.push(DATA_TYPE_BIGVARCHAR);
209            out.extend_from_slice(
210                &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
211            );
212            out.extend_from_slice(&DEFAULT_COLLATION);
213        }
214        MssqlType::VarBinary => {
215            out.push(DATA_TYPE_BIGVARBINARY);
216            out.extend_from_slice(
217                &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
218            );
219        }
220        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
221    }
222
223    Ok(())
224}
225
226fn write_param_len_data(
227    out: &mut Vec<u8>,
228    type_info: &MssqlTypeInfo,
229    encoded: &[u8],
230    is_null: bool,
231) -> Result<(), BoxDynError> {
232    match type_info.kind() {
233        MssqlType::Bit
234        | MssqlType::TinyInt
235        | MssqlType::SmallInt
236        | MssqlType::Int
237        | MssqlType::BigInt
238        | MssqlType::Real
239        | MssqlType::Float => {
240            out.push(if is_null {
241                0
242            } else {
243                u8::try_from(encoded.len())?
244            });
245        }
246        MssqlType::NVarChar | MssqlType::VarChar | MssqlType::VarBinary => {
247            if type_info.size() == Some(u16::MAX) {
248                write_plp_value(out, encoded, is_null)?;
249            } else {
250                let len = if is_null {
251                    u16::MAX
252                } else {
253                    u16::try_from(encoded.len())?
254                };
255                out.extend_from_slice(&len.to_le_bytes());
256            }
257        }
258        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
259    }
260
261    if !is_null && type_info.size() != Some(u16::MAX) {
262        out.extend_from_slice(encoded);
263    }
264
265    Ok(())
266}
267
268fn declaration(
269    type_info: &MssqlTypeInfo,
270    encoded_len: usize,
271    is_null: bool,
272) -> Result<String, BoxDynError> {
273    Ok(match type_info.kind() {
274        MssqlType::Bit => "bit".to_owned(),
275        MssqlType::TinyInt => "tinyint".to_owned(),
276        MssqlType::SmallInt => "smallint".to_owned(),
277        MssqlType::Int => "int".to_owned(),
278        MssqlType::BigInt => "bigint".to_owned(),
279        MssqlType::Real => "real".to_owned(),
280        MssqlType::Float => "float".to_owned(),
281        MssqlType::NVarChar => nvarchar_declaration(type_info, encoded_len, is_null)?,
282        MssqlType::VarChar => varchar_declaration(type_info, encoded_len, is_null)?,
283        MssqlType::VarBinary => varbinary_declaration(type_info, encoded_len, is_null)?,
284        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
285    })
286}
287
288fn nvarchar_declaration(
289    type_info: &MssqlTypeInfo,
290    encoded_len: usize,
291    is_null: bool,
292) -> Result<String, BoxDynError> {
293    let size = nvarchar_type_size(type_info, encoded_len, is_null)?;
294    if size == u16::MAX {
295        Ok("nvarchar(max)".to_owned())
296    } else {
297        Ok(format!("nvarchar({})", size / 2))
298    }
299}
300
301fn varchar_declaration(
302    type_info: &MssqlTypeInfo,
303    encoded_len: usize,
304    is_null: bool,
305) -> Result<String, BoxDynError> {
306    let size = bounded_short_len(type_info, encoded_len, is_null)?;
307    if size == u16::MAX {
308        Ok("varchar(max)".to_owned())
309    } else {
310        Ok(format!("varchar({size})"))
311    }
312}
313
314fn varbinary_declaration(
315    type_info: &MssqlTypeInfo,
316    encoded_len: usize,
317    is_null: bool,
318) -> Result<String, BoxDynError> {
319    let size = bounded_short_len(type_info, encoded_len, is_null)?;
320    if size == u16::MAX {
321        Ok("varbinary(max)".to_owned())
322    } else {
323        Ok(format!("varbinary({size})"))
324    }
325}
326
327fn nvarchar_type_size(
328    type_info: &MssqlTypeInfo,
329    encoded_len: usize,
330    is_null: bool,
331) -> Result<u16, BoxDynError> {
332    if let Some(size) = type_info.size() {
333        return Ok(size);
334    }
335
336    let len = if is_null {
337        2
338    } else {
339        std::cmp::max(2, encoded_len)
340    };
341    Ok(u16::try_from(len)?)
342}
343
344fn bounded_short_len(
345    type_info: &MssqlTypeInfo,
346    encoded_len: usize,
347    is_null: bool,
348) -> Result<u16, BoxDynError> {
349    if let Some(size) = type_info.size() {
350        return Ok(size);
351    }
352
353    let len = if is_null {
354        1
355    } else {
356        std::cmp::max(1, encoded_len)
357    };
358    Ok(u16::try_from(len)?)
359}
360
361fn write_plp_value(out: &mut Vec<u8>, encoded: &[u8], is_null: bool) -> Result<(), BoxDynError> {
362    if is_null {
363        out.extend_from_slice(&PLP_NULL.to_le_bytes());
364        return Ok(());
365    }
366
367    out.extend_from_slice(&u64::try_from(encoded.len())?.to_le_bytes());
368
369    for chunk in encoded.chunks(PLP_CHUNK_SIZE) {
370        out.extend_from_slice(&u32::try_from(chunk.len())?.to_le_bytes());
371        out.extend_from_slice(chunk);
372    }
373
374    out.extend_from_slice(&0_u32.to_le_bytes());
375    Ok(())
376}
377
378fn write_b_varchar(out: &mut Vec<u8>, value: &str) -> Result<(), BoxDynError> {
379    let char_len = value.encode_utf16().count();
380    out.push(u8::try_from(char_len)?);
381    write_utf16(out, value);
382    Ok(())
383}
384
/// Appends `value` to `out` as UTF-16 code units in little-endian byte order.
fn write_utf16(out: &mut Vec<u8>, value: &str) {
    out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `needle` occurs as a contiguous run inside `haystack`.
    fn contains(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn formats_sql_server_style_placeholders() {
        let args = MssqlArguments {
            len: 3,
            ..MssqlArguments::default()
        };
        let mut placeholder = String::new();

        args.format_placeholder(&mut placeholder).unwrap();

        assert_eq!("@p3", placeholder);
    }

    #[test]
    fn records_declarations_and_rpc_argument_data() {
        let mut args = MssqlArguments::default();
        args.add(7_i32).unwrap();
        args.add("hi").unwrap();

        assert_eq!("@p1 int,@p2 nvarchar(2)", args.declarations());
        let data = args.data();
        assert!(contains(data, &[DATA_TYPE_INTN, 4]));
        assert!(contains(
            data,
            &[DATA_TYPE_NVARCHAR, 4, 0, 0x81, 0x04, 0xd0, 0x00, 0x34]
        ));
    }

    #[test]
    fn declares_lossless_integer_parameter_types() {
        let mut args = MssqlArguments::default();
        args.add(-5_i8).unwrap();
        args.add(255_u8).unwrap();
        args.add(65_535_u16).unwrap();
        args.add(u32::MAX).unwrap();

        assert_eq!(
            "@p1 smallint,@p2 tinyint,@p3 int,@p4 bigint",
            args.declarations()
        );
        let data = args.data();
        assert!(contains(data, &[DATA_TYPE_INTN, 1]));
        assert!(contains(data, &[DATA_TYPE_INTN, 8]));
    }

    #[test]
    fn declares_large_text_and_binary_parameters_as_max() {
        let mut args = MssqlArguments::default();
        let text = "x".repeat(4001);
        let bytes = vec![0x5a; 8001];

        args.add(text.as_str()).unwrap();
        args.add(bytes.as_slice()).unwrap();

        assert_eq!("@p1 nvarchar(max),@p2 varbinary(max)", args.declarations());
        let data = args.data();
        assert!(contains(data, &[DATA_TYPE_NVARCHAR, 0xff, 0xff]));
        assert!(contains(data, &[DATA_TYPE_BIGVARBINARY, 0xff, 0xff]));
        // PLP total lengths: 4001 UTF-16 units are 8002 bytes; the binary is 8001 bytes.
        assert!(contains(data, &8002_u64.to_le_bytes()));
        assert!(contains(data, &8001_u64.to_le_bytes()));
    }
}