1use std::fmt::{self, Write};
2
3use sqlx_core::arguments::Arguments;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7
8use crate::{Mssql, MssqlType, MssqlTypeInfo};
9
/// Type byte written for `TinyInt`, `SmallInt`, `Int` and `BigInt` parameters.
const DATA_TYPE_INTN: u8 = 0x26;
/// Type byte written for `Bit` parameters.
const DATA_TYPE_BITN: u8 = 0x68;
/// Type byte written for `Real` and `Float` parameters.
const DATA_TYPE_FLOATN: u8 = 0x6d;
/// Type byte written for `VarBinary` parameters.
const DATA_TYPE_BIGVARBINARY: u8 = 0xa5;
/// Type byte written for `VarChar` parameters.
const DATA_TYPE_BIGVARCHAR: u8 = 0xa7;
/// Type byte written for `NVarChar` parameters.
const DATA_TYPE_NVARCHAR: u8 = 0xe7;
/// Five collation bytes written after the size of `NVarChar` and `VarChar` types.
const DEFAULT_COLLATION: [u8; 5] = [0x81, 0x04, 0xd0, 0x00, 0x34];
/// Marker written instead of a total length when a PLP value is NULL.
const PLP_NULL: u64 = u64::MAX;
/// Largest number of payload bytes placed in a single PLP chunk.
const PLP_CHUNK_SIZE: usize = 8 * 1024;
/// Status byte used for the parameter written by `write_output_i32_parameter`.
const STATUS_BY_REF_VALUE: u8 = 0x01;
20
/// Parameters collected for a SQL Server statement.
///
/// Each added value contributes one `@pN <type>` entry to the declaration
/// list and one encoded parameter to the byte buffer.
#[derive(Clone, Debug, Default)]
pub struct MssqlArguments {
    /// Number of parameters added so far; the most recent one is named `@p{len}`.
    len: usize,
    /// Encoded parameters, appended in the order they were added.
    data: Vec<u8>,
    /// Comma-separated `@pN <type>` declarations, one per parameter.
    declarations: String,
}
28
29impl MssqlArguments {
30 pub const fn is_empty(&self) -> bool {
32 self.len == 0
33 }
34
35 pub(crate) fn data(&self) -> &[u8] {
36 &self.data
37 }
38
39 pub(crate) fn declarations(&self) -> &str {
40 &self.declarations
41 }
42
43 fn add_parameter(
44 &mut self,
45 type_info: MssqlTypeInfo,
46 encoded: Vec<u8>,
47 is_null: bool,
48 ) -> Result<(), BoxDynError> {
49 self.len += 1;
50 let name = format!("@p{}", self.len);
51
52 if !self.declarations.is_empty() {
53 self.declarations.push(',');
54 }
55
56 write!(
57 self.declarations,
58 "{name} {}",
59 declaration(&type_info, encoded.len(), is_null)?
60 )?;
61
62 write_parameter(&mut self.data, &name, &type_info, &encoded, is_null)?;
63 Ok(())
64 }
65}
66
67impl Arguments for MssqlArguments {
68 type Database = Mssql;
69
70 fn reserve(&mut self, _additional: usize, _size: usize) {}
71
72 fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
73 where
74 T: Encode<'t, Self::Database> + Type<Self::Database>,
75 {
76 let type_info = value.produces().unwrap_or_else(T::type_info);
77 let mut encoded = Vec::with_capacity(value.size_hint());
78 let is_null = matches!(value.encode(&mut encoded)?, IsNull::Yes);
79 self.add_parameter(type_info, encoded, is_null)?;
80 Ok(())
81 }
82
83 fn len(&self) -> usize {
84 self.len
85 }
86
87 fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
88 write!(writer, "@p{}", self.len)
89 }
90}
91
92pub(crate) fn write_parameter(
93 out: &mut Vec<u8>,
94 name: &str,
95 type_info: &MssqlTypeInfo,
96 encoded: &[u8],
97 is_null: bool,
98) -> Result<(), BoxDynError> {
99 write_parameter_with_status(out, name, 0, type_info, encoded, is_null)
100}
101
102pub(crate) fn write_output_i32_parameter(
103 out: &mut Vec<u8>,
104 name: &str,
105 value: i32,
106) -> Result<(), BoxDynError> {
107 write_parameter_with_status(
108 out,
109 name,
110 STATUS_BY_REF_VALUE,
111 &MssqlTypeInfo::INT,
112 &value.to_le_bytes(),
113 false,
114 )
115}
116
117fn write_parameter_with_status(
118 out: &mut Vec<u8>,
119 name: &str,
120 status: u8,
121 type_info: &MssqlTypeInfo,
122 encoded: &[u8],
123 is_null: bool,
124) -> Result<(), BoxDynError> {
125 write_b_varchar(out, name)?;
126 out.push(status);
127 write_type_info(out, type_info, encoded.len(), is_null)?;
128 write_param_len_data(out, type_info, encoded, is_null)?;
129 Ok(())
130}
131
132pub(crate) fn write_nvarchar_parameter(
133 out: &mut Vec<u8>,
134 name: &str,
135 value: &str,
136) -> Result<(), BoxDynError> {
137 let mut encoded = Vec::with_capacity(value.len() * 2);
138 write_utf16(&mut encoded, value);
139 write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &encoded, false)
140}
141
142pub(crate) fn write_null_nvarchar_parameter(
143 out: &mut Vec<u8>,
144 name: &str,
145) -> Result<(), BoxDynError> {
146 write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &[], true)
147}
148
149pub(crate) fn type_declaration(type_info: &MssqlTypeInfo) -> Result<&'static str, BoxDynError> {
150 Ok(match type_info.kind() {
151 MssqlType::Bit => "bit",
152 MssqlType::TinyInt => "tinyint",
153 MssqlType::SmallInt => "smallint",
154 MssqlType::Int => "int",
155 MssqlType::BigInt => "bigint",
156 MssqlType::Real => "real",
157 MssqlType::Float => "float",
158 MssqlType::NVarChar => "nvarchar(max)",
159 MssqlType::VarChar => "varchar(max)",
160 MssqlType::VarBinary => "varbinary(max)",
161 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
162 })
163}
164
165fn write_type_info(
166 out: &mut Vec<u8>,
167 type_info: &MssqlTypeInfo,
168 encoded_len: usize,
169 is_null: bool,
170) -> Result<(), BoxDynError> {
171 match type_info.kind() {
172 MssqlType::Bit => {
173 out.push(DATA_TYPE_BITN);
174 out.push(1);
175 }
176 MssqlType::TinyInt => {
177 out.push(DATA_TYPE_INTN);
178 out.push(1);
179 }
180 MssqlType::SmallInt => {
181 out.push(DATA_TYPE_INTN);
182 out.push(2);
183 }
184 MssqlType::Int => {
185 out.push(DATA_TYPE_INTN);
186 out.push(4);
187 }
188 MssqlType::BigInt => {
189 out.push(DATA_TYPE_INTN);
190 out.push(8);
191 }
192 MssqlType::Real => {
193 out.push(DATA_TYPE_FLOATN);
194 out.push(4);
195 }
196 MssqlType::Float => {
197 out.push(DATA_TYPE_FLOATN);
198 out.push(8);
199 }
200 MssqlType::NVarChar => {
201 out.push(DATA_TYPE_NVARCHAR);
202 out.extend_from_slice(
203 &nvarchar_type_size(type_info, encoded_len, is_null)?.to_le_bytes(),
204 );
205 out.extend_from_slice(&DEFAULT_COLLATION);
206 }
207 MssqlType::VarChar => {
208 out.push(DATA_TYPE_BIGVARCHAR);
209 out.extend_from_slice(
210 &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
211 );
212 out.extend_from_slice(&DEFAULT_COLLATION);
213 }
214 MssqlType::VarBinary => {
215 out.push(DATA_TYPE_BIGVARBINARY);
216 out.extend_from_slice(
217 &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
218 );
219 }
220 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
221 }
222
223 Ok(())
224}
225
226fn write_param_len_data(
227 out: &mut Vec<u8>,
228 type_info: &MssqlTypeInfo,
229 encoded: &[u8],
230 is_null: bool,
231) -> Result<(), BoxDynError> {
232 match type_info.kind() {
233 MssqlType::Bit
234 | MssqlType::TinyInt
235 | MssqlType::SmallInt
236 | MssqlType::Int
237 | MssqlType::BigInt
238 | MssqlType::Real
239 | MssqlType::Float => {
240 out.push(if is_null {
241 0
242 } else {
243 u8::try_from(encoded.len())?
244 });
245 }
246 MssqlType::NVarChar | MssqlType::VarChar | MssqlType::VarBinary => {
247 if type_info.size() == Some(u16::MAX) {
248 write_plp_value(out, encoded, is_null)?;
249 } else {
250 let len = if is_null {
251 u16::MAX
252 } else {
253 u16::try_from(encoded.len())?
254 };
255 out.extend_from_slice(&len.to_le_bytes());
256 }
257 }
258 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
259 }
260
261 if !is_null && type_info.size() != Some(u16::MAX) {
262 out.extend_from_slice(encoded);
263 }
264
265 Ok(())
266}
267
268fn declaration(
269 type_info: &MssqlTypeInfo,
270 encoded_len: usize,
271 is_null: bool,
272) -> Result<String, BoxDynError> {
273 Ok(match type_info.kind() {
274 MssqlType::Bit => "bit".to_owned(),
275 MssqlType::TinyInt => "tinyint".to_owned(),
276 MssqlType::SmallInt => "smallint".to_owned(),
277 MssqlType::Int => "int".to_owned(),
278 MssqlType::BigInt => "bigint".to_owned(),
279 MssqlType::Real => "real".to_owned(),
280 MssqlType::Float => "float".to_owned(),
281 MssqlType::NVarChar => nvarchar_declaration(type_info, encoded_len, is_null)?,
282 MssqlType::VarChar => varchar_declaration(type_info, encoded_len, is_null)?,
283 MssqlType::VarBinary => varbinary_declaration(type_info, encoded_len, is_null)?,
284 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
285 })
286}
287
288fn nvarchar_declaration(
289 type_info: &MssqlTypeInfo,
290 encoded_len: usize,
291 is_null: bool,
292) -> Result<String, BoxDynError> {
293 let size = nvarchar_type_size(type_info, encoded_len, is_null)?;
294 if size == u16::MAX {
295 Ok("nvarchar(max)".to_owned())
296 } else {
297 Ok(format!("nvarchar({})", size / 2))
298 }
299}
300
301fn varchar_declaration(
302 type_info: &MssqlTypeInfo,
303 encoded_len: usize,
304 is_null: bool,
305) -> Result<String, BoxDynError> {
306 let size = bounded_short_len(type_info, encoded_len, is_null)?;
307 if size == u16::MAX {
308 Ok("varchar(max)".to_owned())
309 } else {
310 Ok(format!("varchar({size})"))
311 }
312}
313
314fn varbinary_declaration(
315 type_info: &MssqlTypeInfo,
316 encoded_len: usize,
317 is_null: bool,
318) -> Result<String, BoxDynError> {
319 let size = bounded_short_len(type_info, encoded_len, is_null)?;
320 if size == u16::MAX {
321 Ok("varbinary(max)".to_owned())
322 } else {
323 Ok(format!("varbinary({size})"))
324 }
325}
326
327fn nvarchar_type_size(
328 type_info: &MssqlTypeInfo,
329 encoded_len: usize,
330 is_null: bool,
331) -> Result<u16, BoxDynError> {
332 if let Some(size) = type_info.size() {
333 return Ok(size);
334 }
335
336 let len = if is_null {
337 2
338 } else {
339 std::cmp::max(2, encoded_len)
340 };
341 Ok(u16::try_from(len)?)
342}
343
344fn bounded_short_len(
345 type_info: &MssqlTypeInfo,
346 encoded_len: usize,
347 is_null: bool,
348) -> Result<u16, BoxDynError> {
349 if let Some(size) = type_info.size() {
350 return Ok(size);
351 }
352
353 let len = if is_null {
354 1
355 } else {
356 std::cmp::max(1, encoded_len)
357 };
358 Ok(u16::try_from(len)?)
359}
360
361fn write_plp_value(out: &mut Vec<u8>, encoded: &[u8], is_null: bool) -> Result<(), BoxDynError> {
362 if is_null {
363 out.extend_from_slice(&PLP_NULL.to_le_bytes());
364 return Ok(());
365 }
366
367 out.extend_from_slice(&u64::try_from(encoded.len())?.to_le_bytes());
368
369 for chunk in encoded.chunks(PLP_CHUNK_SIZE) {
370 out.extend_from_slice(&u32::try_from(chunk.len())?.to_le_bytes());
371 out.extend_from_slice(chunk);
372 }
373
374 out.extend_from_slice(&0_u32.to_le_bytes());
375 Ok(())
376}
377
378fn write_b_varchar(out: &mut Vec<u8>, value: &str) -> Result<(), BoxDynError> {
379 let char_len = value.encode_utf16().count();
380 out.push(u8::try_from(char_len)?);
381 write_utf16(out, value);
382 Ok(())
383}
384
/// Appends `value` to `out` as UTF-16, each code unit in little-endian
/// byte order.
fn write_utf16(out: &mut Vec<u8>, value: &str) {
    out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `needle` occurs as a contiguous run in `haystack`.
    ///
    /// `needle` must not be empty.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack
            .windows(needle.len())
            .any(|window| window == needle)
    }

    #[test]
    fn formats_sql_server_style_placeholders() {
        let args = MssqlArguments {
            len: 3,
            ..MssqlArguments::default()
        };
        let mut rendered = String::new();

        args.format_placeholder(&mut rendered).unwrap();

        assert_eq!("@p3", rendered);
    }

    #[test]
    fn records_declarations_and_rpc_argument_data() {
        let mut args = MssqlArguments::default();

        args.add(7_i32).unwrap();
        args.add("hi").unwrap();

        assert_eq!("@p1 int,@p2 nvarchar(2)", args.declarations());
        assert!(contains_bytes(args.data(), &[DATA_TYPE_INTN, 4]));
        // Type byte, two-byte size of 4, then the collation bytes.
        assert!(contains_bytes(
            args.data(),
            &[DATA_TYPE_NVARCHAR, 4, 0, 0x81, 0x04, 0xd0, 0x00, 0x34],
        ));
    }

    #[test]
    fn declares_lossless_integer_parameter_types() {
        let mut args = MssqlArguments::default();

        args.add(-5_i8).unwrap();
        args.add(255_u8).unwrap();
        args.add(65_535_u16).unwrap();
        args.add(u32::MAX).unwrap();

        assert_eq!(
            "@p1 smallint,@p2 tinyint,@p3 int,@p4 bigint",
            args.declarations()
        );
        assert!(contains_bytes(args.data(), &[DATA_TYPE_INTN, 1]));
        assert!(contains_bytes(args.data(), &[DATA_TYPE_INTN, 8]));
    }

    #[test]
    fn declares_large_text_and_binary_parameters_as_max() {
        let mut args = MssqlArguments::default();
        let text = "x".repeat(4001);
        let payload = vec![0x5a; 8001];

        args.add(text.as_str()).unwrap();
        args.add(payload.as_slice()).unwrap();

        assert_eq!("@p1 nvarchar(max),@p2 varbinary(max)", args.declarations());

        // Both type headers carry the 0xFFFF size.
        assert!(contains_bytes(args.data(), &[DATA_TYPE_NVARCHAR, 0xff, 0xff]));
        assert!(contains_bytes(
            args.data(),
            &[DATA_TYPE_BIGVARBINARY, 0xff, 0xff],
        ));

        // Total lengths: 4001 UTF-16 units are 8002 bytes; the binary value is 8001 bytes.
        assert!(contains_bytes(args.data(), &8002_u64.to_le_bytes()));
        assert!(contains_bytes(args.data(), &8001_u64.to_le_bytes()));
    }
}