1use std::fmt::{self, Write};
2
3use sqlx_core::arguments::Arguments;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7
8use crate::{Mssql, MssqlType, MssqlTypeInfo};
9
// Type tokens written as the first byte of each parameter's TYPE_INFO
// (see `write_type_info` for which SQL type maps to which token).

/// Nullable integer token; a following length byte selects tinyint/smallint/int/bigint.
const DATA_TYPE_INTN: u8 = 0x26;
/// Nullable `bit` token.
const DATA_TYPE_BITN: u8 = 0x68;
/// Nullable floating-point token; a following length byte selects `real` (4) or `float` (8).
const DATA_TYPE_FLOATN: u8 = 0x6d;
/// `uniqueidentifier` token.
const DATA_TYPE_GUID: u8 = 0x24;
/// Nullable `decimal` token (followed by length, precision and scale bytes).
const DATA_TYPE_NUMERICN: u8 = 0x6c;
/// Nullable `date` token (no trailing length byte is written for it).
const DATA_TYPE_DATEN: u8 = 0x28;
/// Nullable `time` token (followed by a scale byte).
const DATA_TYPE_TIMEN: u8 = 0x29;
/// Nullable `datetime2` token (followed by a scale byte).
const DATA_TYPE_DATETIME2N: u8 = 0x2a;
/// Nullable `datetimeoffset` token (followed by a scale byte).
const DATA_TYPE_DATETIMEOFFSETN: u8 = 0x2b;
/// `varbinary` token (followed by a two-byte max length).
const DATA_TYPE_BIGVARBINARY: u8 = 0xa5;
/// `varchar` token (followed by a two-byte max length and the collation).
const DATA_TYPE_BIGVARCHAR: u8 = 0xa7;
/// `nvarchar` token (followed by a two-byte max length and the collation).
const DATA_TYPE_NVARCHAR: u8 = 0xe7;
/// Five collation bytes appended to `varchar`/`nvarchar` TYPE_INFO.
// NOTE(review): which server collation these bytes encode is not visible here — confirm.
const DEFAULT_COLLATION: [u8; 5] = [0x81, 0x04, 0xd0, 0x00, 0x34];
/// Written in place of the 8-byte total length to send a NULL PLP value.
const PLP_NULL: u64 = 0xffff_ffff_ffff_ffff;
/// Maximum payload bytes per chunk when streaming a PLP value.
const PLP_CHUNK_SIZE: usize = 8192;
/// Parameter status flag used for output (by-reference) parameters.
const STATUS_BY_REF_VALUE: u8 = 0x01;
26
/// Parameters bound for a SQL Server RPC call, accumulated in the order they
/// were added.
#[derive(Debug, Default, Clone)]
pub struct MssqlArguments {
    /// Number of parameters added so far; the `n`-th (1-based) is named `@p{n}`.
    len: usize,
    /// Encoded RPC parameter entries (name, status, TYPE_INFO, value), back to back.
    data: Vec<u8>,
    /// Comma-separated `@pN <sql type>` declarations, one per bound parameter.
    declarations: String,
}
34
35impl MssqlArguments {
36 pub const fn is_empty(&self) -> bool {
38 self.len == 0
39 }
40
41 pub(crate) fn data(&self) -> &[u8] {
42 &self.data
43 }
44
45 pub(crate) fn declarations(&self) -> &str {
46 &self.declarations
47 }
48
49 fn add_parameter(
50 &mut self,
51 type_info: MssqlTypeInfo,
52 encoded: Vec<u8>,
53 is_null: bool,
54 ) -> Result<(), BoxDynError> {
55 self.len += 1;
56 let name = format!("@p{}", self.len);
57
58 if !self.declarations.is_empty() {
59 self.declarations.push(',');
60 }
61
62 write!(
63 self.declarations,
64 "{name} {}",
65 declaration(&type_info, encoded.len(), is_null)?
66 )?;
67
68 write_parameter(&mut self.data, &name, &type_info, &encoded, is_null)?;
69 Ok(())
70 }
71}
72
73impl Arguments for MssqlArguments {
74 type Database = Mssql;
75
76 fn reserve(&mut self, _additional: usize, _size: usize) {}
77
78 fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
79 where
80 T: Encode<'t, Self::Database> + Type<Self::Database>,
81 {
82 let type_info = value.produces().unwrap_or_else(T::type_info);
83 let mut encoded = Vec::with_capacity(value.size_hint());
84 let is_null = matches!(value.encode(&mut encoded)?, IsNull::Yes);
85 self.add_parameter(type_info, encoded, is_null)?;
86 Ok(())
87 }
88
89 fn len(&self) -> usize {
90 self.len
91 }
92
93 fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
94 write!(writer, "@p{}", self.len)
95 }
96}
97
98pub(crate) fn write_parameter(
99 out: &mut Vec<u8>,
100 name: &str,
101 type_info: &MssqlTypeInfo,
102 encoded: &[u8],
103 is_null: bool,
104) -> Result<(), BoxDynError> {
105 write_parameter_with_status(out, name, 0, type_info, encoded, is_null)
106}
107
108pub(crate) fn write_output_i32_parameter(
109 out: &mut Vec<u8>,
110 name: &str,
111 value: i32,
112) -> Result<(), BoxDynError> {
113 write_parameter_with_status(
114 out,
115 name,
116 STATUS_BY_REF_VALUE,
117 &MssqlTypeInfo::INT,
118 &value.to_le_bytes(),
119 false,
120 )
121}
122
123fn write_parameter_with_status(
124 out: &mut Vec<u8>,
125 name: &str,
126 status: u8,
127 type_info: &MssqlTypeInfo,
128 encoded: &[u8],
129 is_null: bool,
130) -> Result<(), BoxDynError> {
131 write_b_varchar(out, name)?;
132 out.push(status);
133 write_type_info(out, type_info, encoded.len(), is_null)?;
134 write_param_len_data(out, type_info, encoded, is_null)?;
135 Ok(())
136}
137
138pub(crate) fn write_nvarchar_parameter(
139 out: &mut Vec<u8>,
140 name: &str,
141 value: &str,
142) -> Result<(), BoxDynError> {
143 let mut encoded = Vec::with_capacity(value.len() * 2);
144 write_utf16(&mut encoded, value);
145 write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &encoded, false)
146}
147
148pub(crate) fn write_null_nvarchar_parameter(
149 out: &mut Vec<u8>,
150 name: &str,
151) -> Result<(), BoxDynError> {
152 write_parameter(out, name, &MssqlTypeInfo::NVARCHAR, &[], true)
153}
154
155pub(crate) fn type_declaration(type_info: &MssqlTypeInfo) -> Result<&'static str, BoxDynError> {
156 Ok(match type_info.kind() {
157 MssqlType::Bit => "bit",
158 MssqlType::TinyInt => "tinyint",
159 MssqlType::SmallInt => "smallint",
160 MssqlType::Int => "int",
161 MssqlType::BigInt => "bigint",
162 MssqlType::Real => "real",
163 MssqlType::Float => "float",
164 MssqlType::NVarChar => "nvarchar(max)",
165 MssqlType::VarChar => "varchar(max)",
166 MssqlType::VarBinary => "varbinary(max)",
167 MssqlType::Decimal => "decimal(38, 0)",
168 MssqlType::Date => "date",
169 MssqlType::Time => "time(7)",
170 MssqlType::DateTime2 => "datetime2(7)",
171 MssqlType::DateTimeOffset => "datetimeoffset(7)",
172 MssqlType::UniqueIdentifier => "uniqueidentifier",
173 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
174 })
175}
176
177fn write_type_info(
178 out: &mut Vec<u8>,
179 type_info: &MssqlTypeInfo,
180 encoded_len: usize,
181 is_null: bool,
182) -> Result<(), BoxDynError> {
183 match type_info.kind() {
184 MssqlType::Bit => {
185 out.push(DATA_TYPE_BITN);
186 out.push(1);
187 }
188 MssqlType::TinyInt => {
189 out.push(DATA_TYPE_INTN);
190 out.push(1);
191 }
192 MssqlType::SmallInt => {
193 out.push(DATA_TYPE_INTN);
194 out.push(2);
195 }
196 MssqlType::Int => {
197 out.push(DATA_TYPE_INTN);
198 out.push(4);
199 }
200 MssqlType::BigInt => {
201 out.push(DATA_TYPE_INTN);
202 out.push(8);
203 }
204 MssqlType::Real => {
205 out.push(DATA_TYPE_FLOATN);
206 out.push(4);
207 }
208 MssqlType::Float => {
209 out.push(DATA_TYPE_FLOATN);
210 out.push(8);
211 }
212 MssqlType::NVarChar => {
213 out.push(DATA_TYPE_NVARCHAR);
214 out.extend_from_slice(
215 &nvarchar_type_size(type_info, encoded_len, is_null)?.to_le_bytes(),
216 );
217 out.extend_from_slice(&DEFAULT_COLLATION);
218 }
219 MssqlType::VarChar => {
220 out.push(DATA_TYPE_BIGVARCHAR);
221 out.extend_from_slice(
222 &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
223 );
224 out.extend_from_slice(&DEFAULT_COLLATION);
225 }
226 MssqlType::VarBinary => {
227 out.push(DATA_TYPE_BIGVARBINARY);
228 out.extend_from_slice(
229 &bounded_short_len(type_info, encoded_len, is_null)?.to_le_bytes(),
230 );
231 }
232 MssqlType::Decimal => {
233 out.push(DATA_TYPE_NUMERICN);
234 out.push(u8::try_from(type_info.size().unwrap_or(17))?);
235 out.push(type_info.precision().max(1));
236 out.push(type_info.scale());
237 }
238 MssqlType::Date => {
239 out.push(DATA_TYPE_DATEN);
240 }
241 MssqlType::Time => {
242 out.push(DATA_TYPE_TIMEN);
243 out.push(type_info.scale());
244 }
245 MssqlType::DateTime2 => {
246 out.push(DATA_TYPE_DATETIME2N);
247 out.push(type_info.scale());
248 }
249 MssqlType::DateTimeOffset => {
250 out.push(DATA_TYPE_DATETIMEOFFSETN);
251 out.push(type_info.scale());
252 }
253 MssqlType::UniqueIdentifier => {
254 out.push(DATA_TYPE_GUID);
255 out.push(16);
256 }
257 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
258 }
259
260 Ok(())
261}
262
263fn write_param_len_data(
264 out: &mut Vec<u8>,
265 type_info: &MssqlTypeInfo,
266 encoded: &[u8],
267 is_null: bool,
268) -> Result<(), BoxDynError> {
269 match type_info.kind() {
270 MssqlType::Bit
271 | MssqlType::TinyInt
272 | MssqlType::SmallInt
273 | MssqlType::Int
274 | MssqlType::BigInt
275 | MssqlType::Real
276 | MssqlType::Float
277 | MssqlType::Decimal
278 | MssqlType::Date
279 | MssqlType::Time
280 | MssqlType::DateTime2
281 | MssqlType::DateTimeOffset
282 | MssqlType::UniqueIdentifier => {
283 out.push(if is_null {
284 0
285 } else {
286 u8::try_from(encoded.len())?
287 });
288 }
289 MssqlType::NVarChar | MssqlType::VarChar | MssqlType::VarBinary => {
290 if type_info.size() == Some(u16::MAX) {
291 write_plp_value(out, encoded, is_null)?;
292 } else {
293 let len = if is_null {
294 u16::MAX
295 } else {
296 u16::try_from(encoded.len())?
297 };
298 out.extend_from_slice(&len.to_le_bytes());
299 }
300 }
301 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
302 }
303
304 if !is_null && type_info.size() != Some(u16::MAX) {
305 out.extend_from_slice(encoded);
306 }
307
308 Ok(())
309}
310
311fn declaration(
312 type_info: &MssqlTypeInfo,
313 encoded_len: usize,
314 is_null: bool,
315) -> Result<String, BoxDynError> {
316 Ok(match type_info.kind() {
317 MssqlType::Bit => "bit".to_owned(),
318 MssqlType::TinyInt => "tinyint".to_owned(),
319 MssqlType::SmallInt => "smallint".to_owned(),
320 MssqlType::Int => "int".to_owned(),
321 MssqlType::BigInt => "bigint".to_owned(),
322 MssqlType::Real => "real".to_owned(),
323 MssqlType::Float => "float".to_owned(),
324 MssqlType::NVarChar => nvarchar_declaration(type_info, encoded_len, is_null)?,
325 MssqlType::VarChar => varchar_declaration(type_info, encoded_len, is_null)?,
326 MssqlType::VarBinary => varbinary_declaration(type_info, encoded_len, is_null)?,
327 MssqlType::Decimal => format!(
328 "decimal({},{})",
329 type_info.precision().max(1),
330 type_info.scale()
331 ),
332 MssqlType::Date => "date".to_owned(),
333 MssqlType::Time => format!("time({})", type_info.scale()),
334 MssqlType::DateTime2 => format!("datetime2({})", type_info.scale()),
335 MssqlType::DateTimeOffset => format!("datetimeoffset({})", type_info.scale()),
336 MssqlType::UniqueIdentifier => "uniqueidentifier".to_owned(),
337 other => return Err(format!("SQL Server arguments do not support type {other:?}").into()),
338 })
339}
340
341fn nvarchar_declaration(
342 type_info: &MssqlTypeInfo,
343 encoded_len: usize,
344 is_null: bool,
345) -> Result<String, BoxDynError> {
346 let size = nvarchar_type_size(type_info, encoded_len, is_null)?;
347 if size == u16::MAX {
348 Ok("nvarchar(max)".to_owned())
349 } else {
350 Ok(format!("nvarchar({})", size / 2))
351 }
352}
353
354fn varchar_declaration(
355 type_info: &MssqlTypeInfo,
356 encoded_len: usize,
357 is_null: bool,
358) -> Result<String, BoxDynError> {
359 let size = bounded_short_len(type_info, encoded_len, is_null)?;
360 if size == u16::MAX {
361 Ok("varchar(max)".to_owned())
362 } else {
363 Ok(format!("varchar({size})"))
364 }
365}
366
367fn varbinary_declaration(
368 type_info: &MssqlTypeInfo,
369 encoded_len: usize,
370 is_null: bool,
371) -> Result<String, BoxDynError> {
372 let size = bounded_short_len(type_info, encoded_len, is_null)?;
373 if size == u16::MAX {
374 Ok("varbinary(max)".to_owned())
375 } else {
376 Ok(format!("varbinary({size})"))
377 }
378}
379
380fn nvarchar_type_size(
381 type_info: &MssqlTypeInfo,
382 encoded_len: usize,
383 is_null: bool,
384) -> Result<u16, BoxDynError> {
385 if let Some(size) = type_info.size() {
386 return Ok(size);
387 }
388
389 let len = if is_null {
390 2
391 } else {
392 std::cmp::max(2, encoded_len)
393 };
394 Ok(u16::try_from(len)?)
395}
396
397fn bounded_short_len(
398 type_info: &MssqlTypeInfo,
399 encoded_len: usize,
400 is_null: bool,
401) -> Result<u16, BoxDynError> {
402 if let Some(size) = type_info.size() {
403 return Ok(size);
404 }
405
406 let len = if is_null {
407 1
408 } else {
409 std::cmp::max(1, encoded_len)
410 };
411 Ok(u16::try_from(len)?)
412}
413
414fn write_plp_value(out: &mut Vec<u8>, encoded: &[u8], is_null: bool) -> Result<(), BoxDynError> {
415 if is_null {
416 out.extend_from_slice(&PLP_NULL.to_le_bytes());
417 return Ok(());
418 }
419
420 out.extend_from_slice(&u64::try_from(encoded.len())?.to_le_bytes());
421
422 for chunk in encoded.chunks(PLP_CHUNK_SIZE) {
423 out.extend_from_slice(&u32::try_from(chunk.len())?.to_le_bytes());
424 out.extend_from_slice(chunk);
425 }
426
427 out.extend_from_slice(&0_u32.to_le_bytes());
428 Ok(())
429}
430
431fn write_b_varchar(out: &mut Vec<u8>, value: &str) -> Result<(), BoxDynError> {
432 let char_len = value.encode_utf16().count();
433 out.push(u8::try_from(char_len)?);
434 write_utf16(out, value);
435 Ok(())
436}
437
/// Appends `value` to `out` as UTF-16 little-endian code units, with no length
/// prefix and no terminator.
fn write_utf16(out: &mut Vec<u8>, value: &str) {
    out.extend(value.encode_utf16().flat_map(u16::to_le_bytes));
}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `needle` occurs as a contiguous run inside `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn formats_sql_server_style_placeholders() {
        let args = MssqlArguments {
            len: 3,
            ..MssqlArguments::default()
        };
        let mut placeholder = String::new();

        args.format_placeholder(&mut placeholder).unwrap();

        assert_eq!("@p3", placeholder);
    }

    #[test]
    fn records_declarations_and_rpc_argument_data() {
        let mut args = MssqlArguments::default();
        args.add(7_i32).unwrap();
        args.add("hi").unwrap();

        assert_eq!("@p1 int,@p2 nvarchar(2)", args.declarations());

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 4]));
        assert!(contains_bytes(
            data,
            &[DATA_TYPE_NVARCHAR, 4, 0, 0x81, 0x04, 0xd0, 0x00, 0x34]
        ));
    }

    #[test]
    fn declares_lossless_integer_parameter_types() {
        let mut args = MssqlArguments::default();
        args.add(-5_i8).unwrap();
        args.add(255_u8).unwrap();
        args.add(65_535_u16).unwrap();
        args.add(u32::MAX).unwrap();

        assert_eq!(
            "@p1 smallint,@p2 tinyint,@p3 int,@p4 bigint",
            args.declarations()
        );

        let data = args.data();
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 1]));
        assert!(contains_bytes(data, &[DATA_TYPE_INTN, 8]));
    }

    #[test]
    fn declares_large_text_and_binary_parameters_as_max() {
        let mut args = MssqlArguments::default();
        let long_text = "x".repeat(4001);
        let long_bytes = vec![0x5a_u8; 8001];

        args.add(long_text.as_str()).unwrap();
        args.add(long_bytes.as_slice()).unwrap();

        assert_eq!("@p1 nvarchar(max),@p2 varbinary(max)", args.declarations());

        let data = args.data();
        let text_total = 8002_u64.to_le_bytes();
        let bytes_total = 8001_u64.to_le_bytes();
        assert!(contains_bytes(data, &[DATA_TYPE_NVARCHAR, 0xff, 0xff]));
        assert!(contains_bytes(data, &[DATA_TYPE_BIGVARBINARY, 0xff, 0xff]));
        assert!(contains_bytes(data, &text_total));
        assert!(contains_bytes(data, &bytes_total));
    }
}