Skip to main content

sqlx_sqlserver/
arguments.rs

1use std::fmt::{self, Write};
2
3use sqlx_core::arguments::Arguments;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7
8use crate::{Mssql, MssqlType, MssqlTypeInfo};
9
/// TDS `INTNTYPE` token: nullable integer; its byte width (1, 2, 4 or 8) follows in the type info.
const DATA_TYPE_INTN: u8 = 0x26;
/// TDS `BITNTYPE` token: nullable bit, declared with a width of 1 byte.
const DATA_TYPE_BITN: u8 = 0x68;
/// TDS `FLTNTYPE` token: nullable float; its byte width (4 for `real`, 8 for `float`) follows.
const DATA_TYPE_FLOATN: u8 = 0x6D;
/// TDS `BIGVARBINARYTYPE` token: followed by a 2-byte little-endian maximum length.
const DATA_TYPE_BIGVARBINARY: u8 = 0xA5;
/// TDS `NVARCHARTYPE` token: followed by a 2-byte little-endian maximum length and a collation.
const DATA_TYPE_NVARCHAR: u8 = 0xE7;
/// Five collation bytes sent in the type info of every `nvarchar` parameter.
/// NOTE(review): the intended collation is not documented here — confirm against MS-TDS `COLLATION`.
const DEFAULT_COLLATION: [u8; 5] = [0x81, 0x04, 0xD0, 0x00, 0x34];
/// RPC parameter status flag written by `write_output_i32_parameter` (by-reference / output value).
const STATUS_BY_REF_VALUE: u8 = 0x01;
17
/// SQL Server argument buffer.
///
/// For every bound value it holds an RPC parameter record in `data` and a matching
/// `@pN <type>` entry in `declarations`, both in insertion order.
#[derive(Clone, Debug, Default)]
pub struct MssqlArguments {
    /// Number of parameters added so far; parameters are named `@p1`, `@p2`, ...
    len: usize,
    /// Concatenated RPC parameter records (name, status, type info, length-prefixed value).
    data: Vec<u8>,
    /// Comma-separated `@pN <type>` declarations, e.g. `@p1 int,@p2 nvarchar(2)`.
    declarations: String,
}
25
26impl MssqlArguments {
27    /// Returns `true` when no arguments were added.
28    pub const fn is_empty(&self) -> bool {
29        self.len == 0
30    }
31
32    pub(crate) fn data(&self) -> &[u8] {
33        &self.data
34    }
35
36    pub(crate) fn declarations(&self) -> &str {
37        &self.declarations
38    }
39
40    fn add_parameter(
41        &mut self,
42        type_info: MssqlTypeInfo,
43        encoded: Vec<u8>,
44        is_null: bool,
45    ) -> Result<(), BoxDynError> {
46        self.len += 1;
47        let name = format!("@p{}", self.len);
48
49        if !self.declarations.is_empty() {
50            self.declarations.push(',');
51        }
52
53        write!(
54            self.declarations,
55            "{name} {}",
56            declaration(&type_info, encoded.len(), is_null)?
57        )?;
58
59        write_parameter(&mut self.data, &name, &type_info, &encoded, is_null)?;
60        Ok(())
61    }
62}
63
64impl Arguments for MssqlArguments {
65    type Database = Mssql;
66
67    fn reserve(&mut self, _additional: usize, _size: usize) {}
68
69    fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
70    where
71        T: Encode<'t, Self::Database> + Type<Self::Database>,
72    {
73        let type_info = value.produces().unwrap_or_else(T::type_info);
74        let mut encoded = Vec::with_capacity(value.size_hint());
75        let is_null = matches!(value.encode(&mut encoded)?, IsNull::Yes);
76        self.add_parameter(type_info, encoded, is_null)?;
77        Ok(())
78    }
79
80    fn len(&self) -> usize {
81        self.len
82    }
83
84    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
85        write!(writer, "@p{}", self.len)
86    }
87}
88
89pub(crate) fn write_parameter(
90    out: &mut Vec<u8>,
91    name: &str,
92    type_info: &MssqlTypeInfo,
93    encoded: &[u8],
94    is_null: bool,
95) -> Result<(), BoxDynError> {
96    write_parameter_with_status(out, name, 0, type_info, encoded, is_null)
97}
98
99pub(crate) fn write_output_i32_parameter(
100    out: &mut Vec<u8>,
101    name: &str,
102    value: i32,
103) -> Result<(), BoxDynError> {
104    write_parameter_with_status(
105        out,
106        name,
107        STATUS_BY_REF_VALUE,
108        &MssqlTypeInfo::INT,
109        &value.to_le_bytes(),
110        false,
111    )
112}
113
114fn write_parameter_with_status(
115    out: &mut Vec<u8>,
116    name: &str,
117    status: u8,
118    type_info: &MssqlTypeInfo,
119    encoded: &[u8],
120    is_null: bool,
121) -> Result<(), BoxDynError> {
122    write_b_varchar(out, name)?;
123    out.push(status);
124    write_type_info(out, type_info, encoded.len(), is_null)?;
125    write_param_len_data(out, type_info, encoded, is_null)?;
126    Ok(())
127}
128
129pub(crate) fn write_nvarchar_parameter(
130    out: &mut Vec<u8>,
131    name: &str,
132    value: &str,
133) -> Result<(), BoxDynError> {
134    let mut encoded = Vec::with_capacity(value.len() * 2);
135    write_utf16(&mut encoded, value);
136    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &encoded, false)
137}
138
139pub(crate) fn write_null_nvarchar_parameter(
140    out: &mut Vec<u8>,
141    name: &str,
142) -> Result<(), BoxDynError> {
143    write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &[], true)
144}
145
146pub(crate) fn type_declaration(type_info: &MssqlTypeInfo) -> Result<&'static str, BoxDynError> {
147    Ok(match type_info.kind() {
148        MssqlType::Bit => "bit",
149        MssqlType::TinyInt => "tinyint",
150        MssqlType::SmallInt => "smallint",
151        MssqlType::Int => "int",
152        MssqlType::BigInt => "bigint",
153        MssqlType::Real => "real",
154        MssqlType::Float => "float",
155        MssqlType::NVarChar => "nvarchar(max)",
156        MssqlType::VarBinary => "varbinary(max)",
157        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
158    })
159}
160
161fn write_type_info(
162    out: &mut Vec<u8>,
163    type_info: &MssqlTypeInfo,
164    encoded_len: usize,
165    is_null: bool,
166) -> Result<(), BoxDynError> {
167    match type_info.kind() {
168        MssqlType::Bit => {
169            out.push(DATA_TYPE_BITN);
170            out.push(1);
171        }
172        MssqlType::TinyInt => {
173            out.push(DATA_TYPE_INTN);
174            out.push(1);
175        }
176        MssqlType::SmallInt => {
177            out.push(DATA_TYPE_INTN);
178            out.push(2);
179        }
180        MssqlType::Int => {
181            out.push(DATA_TYPE_INTN);
182            out.push(4);
183        }
184        MssqlType::BigInt => {
185            out.push(DATA_TYPE_INTN);
186            out.push(8);
187        }
188        MssqlType::Real => {
189            out.push(DATA_TYPE_FLOATN);
190            out.push(4);
191        }
192        MssqlType::Float => {
193            out.push(DATA_TYPE_FLOATN);
194            out.push(8);
195        }
196        MssqlType::NVarChar => {
197            out.push(DATA_TYPE_NVARCHAR);
198            out.extend_from_slice(&nvarchar_type_size(encoded_len, is_null)?.to_le_bytes());
199            out.extend_from_slice(&DEFAULT_COLLATION);
200        }
201        MssqlType::VarBinary => {
202            out.push(DATA_TYPE_BIGVARBINARY);
203            out.extend_from_slice(&bounded_short_len(encoded_len, is_null)?.to_le_bytes());
204        }
205        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
206    }
207
208    Ok(())
209}
210
211fn write_param_len_data(
212    out: &mut Vec<u8>,
213    type_info: &MssqlTypeInfo,
214    encoded: &[u8],
215    is_null: bool,
216) -> Result<(), BoxDynError> {
217    match type_info.kind() {
218        MssqlType::Bit
219        | MssqlType::TinyInt
220        | MssqlType::SmallInt
221        | MssqlType::Int
222        | MssqlType::BigInt
223        | MssqlType::Real
224        | MssqlType::Float => {
225            out.push(if is_null {
226                0
227            } else {
228                u8::try_from(encoded.len())?
229            });
230        }
231        MssqlType::NVarChar | MssqlType::VarBinary => {
232            let len = if is_null {
233                u16::MAX
234            } else {
235                u16::try_from(encoded.len())?
236            };
237            out.extend_from_slice(&len.to_le_bytes());
238        }
239        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
240    }
241
242    if !is_null {
243        out.extend_from_slice(encoded);
244    }
245
246    Ok(())
247}
248
249fn declaration(
250    type_info: &MssqlTypeInfo,
251    encoded_len: usize,
252    is_null: bool,
253) -> Result<String, BoxDynError> {
254    Ok(match type_info.kind() {
255        MssqlType::Bit => "bit".to_owned(),
256        MssqlType::TinyInt => "tinyint".to_owned(),
257        MssqlType::SmallInt => "smallint".to_owned(),
258        MssqlType::Int => "int".to_owned(),
259        MssqlType::BigInt => "bigint".to_owned(),
260        MssqlType::Real => "real".to_owned(),
261        MssqlType::Float => "float".to_owned(),
262        MssqlType::NVarChar => format!("nvarchar({})", nvarchar_chars(encoded_len, is_null)?),
263        MssqlType::VarBinary => format!("varbinary({})", bounded_short_len(encoded_len, is_null)?),
264        other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
265    })
266}
267
268fn nvarchar_chars(encoded_len: usize, is_null: bool) -> Result<u16, BoxDynError> {
269    Ok(nvarchar_type_size(encoded_len, is_null)? / 2)
270}
271
272fn nvarchar_type_size(encoded_len: usize, is_null: bool) -> Result<u16, BoxDynError> {
273    let len = if is_null {
274        2
275    } else {
276        std::cmp::max(2, encoded_len)
277    };
278    Ok(u16::try_from(len)?)
279}
280
281fn bounded_short_len(encoded_len: usize, is_null: bool) -> Result<u16, BoxDynError> {
282    let len = if is_null {
283        1
284    } else {
285        std::cmp::max(1, encoded_len)
286    };
287    Ok(u16::try_from(len)?)
288}
289
290fn write_b_varchar(out: &mut Vec<u8>, value: &str) -> Result<(), BoxDynError> {
291    let char_len = value.encode_utf16().count();
292    out.push(u8::try_from(char_len)?);
293    write_utf16(out, value);
294    Ok(())
295}
296
/// Appends `value` to `out` as UTF-16 code units in little-endian byte order.
fn write_utf16(out: &mut Vec<u8>, value: &str) {
    out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `needle` appears as a contiguous run inside `haystack`.
    fn contains_run(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn formats_sql_server_style_placeholders() {
        let args = MssqlArguments {
            len: 3,
            data: Vec::new(),
            declarations: String::new(),
        };
        let mut placeholder = String::new();

        args.format_placeholder(&mut placeholder).unwrap();

        assert_eq!("@p3", placeholder);
    }

    #[test]
    fn records_declarations_and_rpc_argument_data() {
        let mut args = MssqlArguments::default();

        args.add(7_i32).unwrap();
        args.add("hi").unwrap();

        assert_eq!("@p1 int,@p2 nvarchar(2)", args.declarations());
        // `int` type info: INTN token with a 4-byte width.
        assert!(contains_run(args.data(), &[DATA_TYPE_INTN, 4]));
        // `nvarchar` type info: token, 2-byte LE size (4 bytes for "hi"), then the collation.
        let nvarchar_type_info = [DATA_TYPE_NVARCHAR, 4, 0, 0x81, 0x04, 0xd0, 0x00, 0x34];
        assert!(contains_run(args.data(), &nvarchar_type_info));
    }

    #[test]
    fn declares_lossless_integer_parameter_types() {
        let mut args = MssqlArguments::default();

        args.add(-5_i8).unwrap();
        args.add(255_u8).unwrap();
        args.add(65_535_u16).unwrap();
        args.add(u32::MAX).unwrap();

        assert_eq!(
            "@p1 smallint,@p2 tinyint,@p3 int,@p4 bigint",
            args.declarations()
        );
        assert!(contains_run(args.data(), &[DATA_TYPE_INTN, 1]));
        assert!(contains_run(args.data(), &[DATA_TYPE_INTN, 8]));
    }
}