1use std::fmt::{self, Write};
2
3use sqlx_core::arguments::Arguments;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7
8use crate::{Mssql, MssqlType, MssqlTypeInfo};
9
// TDS type tokens written at the start of each parameter's TYPE_INFO (see `write_type_info`).
const DATA_TYPE_INTN: u8 = 0x26;
const DATA_TYPE_BITN: u8 = 0x68;
const DATA_TYPE_FLOATN: u8 = 0x6d;
const DATA_TYPE_BIGVARBINARY: u8 = 0xa5;
const DATA_TYPE_NVARCHAR: u8 = 0xe7;
// Five collation bytes appended to every nvarchar TYPE_INFO.
// NOTE(review): which locale / sort flags these bytes encode is not established here — confirm.
const DEFAULT_COLLATION: [u8; 5] = [0x81, 0x04, 0xd0, 0x00, 0x34];
// Parameter status byte used by `write_output_i32_parameter` for output (by-reference)
// parameters; ordinary input parameters are written with status 0.
const STATUS_BY_REF_VALUE: u8 = 0x01;
17
/// Parameters bound to a SQL Server query.
///
/// Each bound value contributes a SQL declaration (`@pN <type>`) and an encoded RPC
/// parameter entry; the two are kept in the same order.
#[derive(Debug, Default, Clone)]
pub struct MssqlArguments {
    // Number of parameters bound so far; parameter N is named `@pN`.
    len: usize,
    // Concatenated per-parameter data: name, status byte, TYPE_INFO, length-prefixed value.
    data: Vec<u8>,
    // Comma-separated `@pN <sql type>` declarations, one per entry in `data`.
    declarations: String,
}
25
26impl MssqlArguments {
27 pub const fn is_empty(&self) -> bool {
29 self.len == 0
30 }
31
32 pub(crate) fn data(&self) -> &[u8] {
33 &self.data
34 }
35
36 pub(crate) fn declarations(&self) -> &str {
37 &self.declarations
38 }
39
40 fn add_parameter(
41 &mut self,
42 type_info: MssqlTypeInfo,
43 encoded: Vec<u8>,
44 is_null: bool,
45 ) -> Result<(), BoxDynError> {
46 self.len += 1;
47 let name = format!("@p{}", self.len);
48
49 if !self.declarations.is_empty() {
50 self.declarations.push(',');
51 }
52
53 write!(
54 self.declarations,
55 "{name} {}",
56 declaration(&type_info, encoded.len(), is_null)?
57 )?;
58
59 write_parameter(&mut self.data, &name, &type_info, &encoded, is_null)?;
60 Ok(())
61 }
62}
63
64impl Arguments for MssqlArguments {
65 type Database = Mssql;
66
67 fn reserve(&mut self, _additional: usize, _size: usize) {}
68
69 fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
70 where
71 T: Encode<'t, Self::Database> + Type<Self::Database>,
72 {
73 let type_info = value.produces().unwrap_or_else(T::type_info);
74 let mut encoded = Vec::with_capacity(value.size_hint());
75 let is_null = matches!(value.encode(&mut encoded)?, IsNull::Yes);
76 self.add_parameter(type_info, encoded, is_null)?;
77 Ok(())
78 }
79
80 fn len(&self) -> usize {
81 self.len
82 }
83
84 fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
85 write!(writer, "@p{}", self.len)
86 }
87}
88
89pub(crate) fn write_parameter(
90 out: &mut Vec<u8>,
91 name: &str,
92 type_info: &MssqlTypeInfo,
93 encoded: &[u8],
94 is_null: bool,
95) -> Result<(), BoxDynError> {
96 write_parameter_with_status(out, name, 0, type_info, encoded, is_null)
97}
98
99pub(crate) fn write_output_i32_parameter(
100 out: &mut Vec<u8>,
101 name: &str,
102 value: i32,
103) -> Result<(), BoxDynError> {
104 write_parameter_with_status(
105 out,
106 name,
107 STATUS_BY_REF_VALUE,
108 &MssqlTypeInfo::INT,
109 &value.to_le_bytes(),
110 false,
111 )
112}
113
114fn write_parameter_with_status(
115 out: &mut Vec<u8>,
116 name: &str,
117 status: u8,
118 type_info: &MssqlTypeInfo,
119 encoded: &[u8],
120 is_null: bool,
121) -> Result<(), BoxDynError> {
122 write_b_varchar(out, name)?;
123 out.push(status);
124 write_type_info(out, type_info, encoded.len(), is_null)?;
125 write_param_len_data(out, type_info, encoded, is_null)?;
126 Ok(())
127}
128
129pub(crate) fn write_nvarchar_parameter(
130 out: &mut Vec<u8>,
131 name: &str,
132 value: &str,
133) -> Result<(), BoxDynError> {
134 let mut encoded = Vec::with_capacity(value.len() * 2);
135 write_utf16(&mut encoded, value);
136 write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &encoded, false)
137}
138
139pub(crate) fn write_null_nvarchar_parameter(
140 out: &mut Vec<u8>,
141 name: &str,
142) -> Result<(), BoxDynError> {
143 write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &[], true)
144}
145
146pub(crate) fn type_declaration(type_info: &MssqlTypeInfo) -> Result<&'static str, BoxDynError> {
147 Ok(match type_info.kind() {
148 MssqlType::Bit => "bit",
149 MssqlType::TinyInt => "tinyint",
150 MssqlType::SmallInt => "smallint",
151 MssqlType::Int => "int",
152 MssqlType::BigInt => "bigint",
153 MssqlType::Real => "real",
154 MssqlType::Float => "float",
155 MssqlType::NVarChar => "nvarchar(max)",
156 MssqlType::VarBinary => "varbinary(max)",
157 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
158 })
159}
160
161fn write_type_info(
162 out: &mut Vec<u8>,
163 type_info: &MssqlTypeInfo,
164 encoded_len: usize,
165 is_null: bool,
166) -> Result<(), BoxDynError> {
167 match type_info.kind() {
168 MssqlType::Bit => {
169 out.push(DATA_TYPE_BITN);
170 out.push(1);
171 }
172 MssqlType::TinyInt => {
173 out.push(DATA_TYPE_INTN);
174 out.push(1);
175 }
176 MssqlType::SmallInt => {
177 out.push(DATA_TYPE_INTN);
178 out.push(2);
179 }
180 MssqlType::Int => {
181 out.push(DATA_TYPE_INTN);
182 out.push(4);
183 }
184 MssqlType::BigInt => {
185 out.push(DATA_TYPE_INTN);
186 out.push(8);
187 }
188 MssqlType::Real => {
189 out.push(DATA_TYPE_FLOATN);
190 out.push(4);
191 }
192 MssqlType::Float => {
193 out.push(DATA_TYPE_FLOATN);
194 out.push(8);
195 }
196 MssqlType::NVarChar => {
197 out.push(DATA_TYPE_NVARCHAR);
198 out.extend_from_slice(&nvarchar_type_size(encoded_len, is_null)?.to_le_bytes());
199 out.extend_from_slice(&DEFAULT_COLLATION);
200 }
201 MssqlType::VarBinary => {
202 out.push(DATA_TYPE_BIGVARBINARY);
203 out.extend_from_slice(&bounded_short_len(encoded_len, is_null)?.to_le_bytes());
204 }
205 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
206 }
207
208 Ok(())
209}
210
211fn write_param_len_data(
212 out: &mut Vec<u8>,
213 type_info: &MssqlTypeInfo,
214 encoded: &[u8],
215 is_null: bool,
216) -> Result<(), BoxDynError> {
217 match type_info.kind() {
218 MssqlType::Bit
219 | MssqlType::TinyInt
220 | MssqlType::SmallInt
221 | MssqlType::Int
222 | MssqlType::BigInt
223 | MssqlType::Real
224 | MssqlType::Float => {
225 out.push(if is_null {
226 0
227 } else {
228 u8::try_from(encoded.len())?
229 });
230 }
231 MssqlType::NVarChar | MssqlType::VarBinary => {
232 let len = if is_null {
233 u16::MAX
234 } else {
235 u16::try_from(encoded.len())?
236 };
237 out.extend_from_slice(&len.to_le_bytes());
238 }
239 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
240 }
241
242 if !is_null {
243 out.extend_from_slice(encoded);
244 }
245
246 Ok(())
247}
248
249fn declaration(
250 type_info: &MssqlTypeInfo,
251 encoded_len: usize,
252 is_null: bool,
253) -> Result<String, BoxDynError> {
254 Ok(match type_info.kind() {
255 MssqlType::Bit => "bit".to_owned(),
256 MssqlType::TinyInt => "tinyint".to_owned(),
257 MssqlType::SmallInt => "smallint".to_owned(),
258 MssqlType::Int => "int".to_owned(),
259 MssqlType::BigInt => "bigint".to_owned(),
260 MssqlType::Real => "real".to_owned(),
261 MssqlType::Float => "float".to_owned(),
262 MssqlType::NVarChar => format!("nvarchar({})", nvarchar_chars(encoded_len, is_null)?),
263 MssqlType::VarBinary => format!("varbinary({})", bounded_short_len(encoded_len, is_null)?),
264 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
265 })
266}
267
268fn nvarchar_chars(encoded_len: usize, is_null: bool) -> Result<u16, BoxDynError> {
269 Ok(nvarchar_type_size(encoded_len, is_null)? / 2)
270}
271
272fn nvarchar_type_size(encoded_len: usize, is_null: bool) -> Result<u16, BoxDynError> {
273 let len = if is_null {
274 2
275 } else {
276 std::cmp::max(2, encoded_len)
277 };
278 Ok(u16::try_from(len)?)
279}
280
281fn bounded_short_len(encoded_len: usize, is_null: bool) -> Result<u16, BoxDynError> {
282 let len = if is_null {
283 1
284 } else {
285 std::cmp::max(1, encoded_len)
286 };
287 Ok(u16::try_from(len)?)
288}
289
290fn write_b_varchar(out: &mut Vec<u8>, value: &str) -> Result<(), BoxDynError> {
291 let char_len = value.encode_utf16().count();
292 out.push(u8::try_from(char_len)?);
293 write_utf16(out, value);
294 Ok(())
295}
296
/// Appends `value` to `out` as UTF-16 code units in little-endian byte order.
fn write_utf16(out: &mut Vec<u8>, value: &str) {
    out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
}
302
#[cfg(test)]
mod tests {
    use super::*;

    // `format_placeholder` names the most recently bound parameter, i.e. `@p{len}`.
    #[test]
    fn formats_sql_server_style_placeholders() {
        let args = MssqlArguments {
            len: 3,
            data: Vec::new(),
            declarations: String::new(),
        };
        let mut out = String::new();

        args.format_placeholder(&mut out).unwrap();

        assert_eq!("@p3", out);
    }

    // Binding values records both the SQL declaration list and the encoded parameter bytes.
    #[test]
    fn records_declarations_and_rpc_argument_data() {
        let mut args = MssqlArguments::default();

        args.add(7_i32).unwrap();
        args.add("hi").unwrap();

        assert_eq!("@p1 int,@p2 nvarchar(2)", args.declarations());
        // INTN token with a 4-byte width for the i32.
        assert!(args
            .data()
            .windows(2)
            .any(|bytes| bytes == [DATA_TYPE_INTN, 4]));
        // NVARCHAR token, little-endian u16 size 4 ("hi" in UTF-16), then DEFAULT_COLLATION.
        assert!(args
            .data()
            .windows(8)
            .any(|bytes| bytes == [DATA_TYPE_NVARCHAR, 4, 0, 0x81, 0x04, 0xd0, 0x00, 0x34]));
    }

    // Per the assertions below: i8 -> smallint, u8 -> tinyint, u16 -> int, u32 -> bigint,
    // so each value fits its declared SQL type without loss.
    #[test]
    fn declares_lossless_integer_parameter_types() {
        let mut args = MssqlArguments::default();

        args.add(-5_i8).unwrap();
        args.add(255_u8).unwrap();
        args.add(65_535_u16).unwrap();
        args.add(u32::MAX).unwrap();

        assert_eq!(
            "@p1 smallint,@p2 tinyint,@p3 int,@p4 bigint",
            args.declarations()
        );
        // tinyint is sent as INTN with a 1-byte width.
        assert!(args
            .data()
            .windows(2)
            .any(|bytes| bytes == [DATA_TYPE_INTN, 1]));
        // bigint is sent as INTN with an 8-byte width.
        assert!(args
            .data()
            .windows(2)
            .any(|bytes| bytes == [DATA_TYPE_INTN, 8]));
    }
}