1use std::borrow::Cow;
2
3use sqlx_core::decode::Decode;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7use sqlx_core::value::{Value, ValueRef};
8
9use crate::{Mssql, MssqlType, MssqlTypeInfo};
10
/// An owned SQL Server value: its type metadata plus its encoded bytes.
///
/// `data` is `None` for SQL `NULL`; otherwise it holds the bytes that the
/// `Decode` impls in this module interpret (little-endian for numeric types).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MssqlValue {
    /// Type metadata describing how `data` should be interpreted.
    type_info: MssqlTypeInfo,
    /// Encoded bytes, or `None` when the value is SQL `NULL`.
    data: Option<Vec<u8>>,
}
17
18impl MssqlValue {
19 pub(crate) fn new(type_info: MssqlTypeInfo, data: Option<Vec<u8>>) -> Self {
21 Self { type_info, data }
22 }
23
24 pub fn null(type_info: MssqlTypeInfo) -> Self {
26 Self {
27 type_info,
28 data: None,
29 }
30 }
31}
32
33impl Value for MssqlValue {
34 type Database = Mssql;
35
36 fn as_ref(&self) -> MssqlValueRef<'_> {
37 MssqlValueRef {
38 type_info: &self.type_info,
39 data: self.data.as_deref(),
40 }
41 }
42
43 fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
44 Cow::Borrowed(&self.type_info)
45 }
46
47 fn is_null(&self) -> bool {
48 self.data.is_none()
49 }
50}
51
/// A borrowed view of a SQL Server value, as passed to the `Decode` impls.
///
/// `data` is `None` for SQL `NULL`.
#[derive(Debug, Clone, Copy)]
pub struct MssqlValueRef<'r> {
    /// Type metadata of the value being decoded.
    type_info: &'r MssqlTypeInfo,
    /// Borrowed encoded bytes, or `None` when the value is SQL `NULL`.
    data: Option<&'r [u8]>,
}
58
59impl<'r> ValueRef<'r> for MssqlValueRef<'r> {
60 type Database = Mssql;
61
62 fn to_owned(&self) -> MssqlValue {
63 MssqlValue {
64 type_info: self.type_info.clone(),
65 data: self.data.map(ToOwned::to_owned),
66 }
67 }
68
69 fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
70 Cow::Borrowed(self.type_info)
71 }
72
73 fn is_null(&self) -> bool {
74 self.data.is_none()
75 }
76}
77
78impl<'r> MssqlValueRef<'r> {
79 pub(crate) fn as_bytes(&self) -> Option<&'r [u8]> {
80 self.data
81 }
82}
83
84fn non_null_bytes<'r>(value: MssqlValueRef<'r>, rust_type: &str) -> Result<&'r [u8], BoxDynError> {
85 value
86 .as_bytes()
87 .ok_or_else(|| format!("cannot decode SQL Server NULL as {rust_type}").into())
88}
89
90fn decode_integer(value: MssqlValueRef<'_>, rust_type: &str) -> Result<i64, BoxDynError> {
91 let bytes = non_null_bytes(value, rust_type)?;
92
93 match bytes.len() {
94 1 => Ok(i64::from(bytes[0])),
95 2 => Ok(i64::from(i16::from_le_bytes([bytes[0], bytes[1]]))),
96 4 => Ok(i64::from(i32::from_le_bytes([
97 bytes[0], bytes[1], bytes[2], bytes[3],
98 ]))),
99 8 => Ok(i64::from_le_bytes([
100 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
101 ])),
102 len => Err(format!("cannot decode {len}-byte SQL Server integer as {rust_type}").into()),
103 }
104}
105
106impl Type<Mssql> for i8 {
107 fn type_info() -> MssqlTypeInfo {
108 MssqlTypeInfo::SMALLINT
109 }
110
111 fn compatible(ty: &MssqlTypeInfo) -> bool {
112 matches!(
113 ty.kind(),
114 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
115 )
116 }
117}
118
119impl Encode<'_, Mssql> for i8 {
120 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
121 <i16 as Encode<Mssql>>::encode_by_ref(&i16::from(*self), buf)
122 }
123}
124
125impl Decode<'_, Mssql> for i8 {
126 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
127 Ok(i8::try_from(decode_integer(value, "i8")?)?)
128 }
129}
130
131impl Type<Mssql> for u8 {
132 fn type_info() -> MssqlTypeInfo {
133 MssqlTypeInfo::TINYINT
134 }
135
136 fn compatible(ty: &MssqlTypeInfo) -> bool {
137 matches!(
138 ty.kind(),
139 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
140 )
141 }
142}
143
144impl Encode<'_, Mssql> for u8 {
145 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
146 buf.push(*self);
147 Ok(IsNull::No)
148 }
149}
150
151impl Decode<'_, Mssql> for u8 {
152 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
153 Ok(u8::try_from(decode_integer(value, "u8")?)?)
154 }
155}
156
157impl Type<Mssql> for i32 {
158 fn type_info() -> MssqlTypeInfo {
159 MssqlTypeInfo::INT
160 }
161
162 fn compatible(ty: &MssqlTypeInfo) -> bool {
163 matches!(
164 ty.kind(),
165 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int
166 )
167 }
168}
169
170impl Encode<'_, Mssql> for i32 {
171 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
172 buf.extend_from_slice(&self.to_le_bytes());
173 Ok(IsNull::No)
174 }
175}
176
177impl Type<Mssql> for bool {
178 fn type_info() -> MssqlTypeInfo {
179 MssqlTypeInfo::BIT
180 }
181
182 fn compatible(ty: &MssqlTypeInfo) -> bool {
183 matches!(ty.kind(), MssqlType::Bit)
184 }
185}
186
187impl Encode<'_, Mssql> for bool {
188 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
189 buf.push(u8::from(*self));
190 Ok(IsNull::No)
191 }
192}
193
194impl Decode<'_, Mssql> for bool {
195 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
196 let bytes = value
197 .as_bytes()
198 .ok_or_else(|| "cannot decode SQL Server NULL as bool".to_owned())?;
199
200 match bytes {
201 [0] => Ok(false),
202 [1] => Ok(true),
203 _ => Err("cannot decode SQL Server bit as bool".into()),
204 }
205 }
206}
207
208impl Type<Mssql> for i16 {
209 fn type_info() -> MssqlTypeInfo {
210 MssqlTypeInfo::SMALLINT
211 }
212
213 fn compatible(ty: &MssqlTypeInfo) -> bool {
214 matches!(ty.kind(), MssqlType::TinyInt | MssqlType::SmallInt)
215 }
216}
217
218impl Encode<'_, Mssql> for i16 {
219 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
220 buf.extend_from_slice(&self.to_le_bytes());
221 Ok(IsNull::No)
222 }
223}
224
225impl Decode<'_, Mssql> for i16 {
226 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
227 Ok(i16::try_from(decode_integer(value, "i16")?)?)
228 }
229}
230
231impl Type<Mssql> for u16 {
232 fn type_info() -> MssqlTypeInfo {
233 MssqlTypeInfo::INT
234 }
235
236 fn compatible(ty: &MssqlTypeInfo) -> bool {
237 matches!(
238 ty.kind(),
239 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
240 )
241 }
242}
243
244impl Encode<'_, Mssql> for u16 {
245 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
246 <i32 as Encode<Mssql>>::encode_by_ref(&i32::from(*self), buf)
247 }
248}
249
250impl Decode<'_, Mssql> for u16 {
251 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
252 Ok(u16::try_from(decode_integer(value, "u16")?)?)
253 }
254}
255
256impl Decode<'_, Mssql> for i32 {
257 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
258 Ok(i32::try_from(decode_integer(value, "i32")?)?)
259 }
260}
261
262impl Type<Mssql> for u32 {
263 fn type_info() -> MssqlTypeInfo {
264 MssqlTypeInfo::BIGINT
265 }
266
267 fn compatible(ty: &MssqlTypeInfo) -> bool {
268 matches!(
269 ty.kind(),
270 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
271 )
272 }
273}
274
275impl Encode<'_, Mssql> for u32 {
276 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
277 <i64 as Encode<Mssql>>::encode_by_ref(&i64::from(*self), buf)
278 }
279}
280
281impl Decode<'_, Mssql> for u32 {
282 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
283 Ok(u32::try_from(decode_integer(value, "u32")?)?)
284 }
285}
286
287impl Type<Mssql> for i64 {
288 fn type_info() -> MssqlTypeInfo {
289 MssqlTypeInfo::BIGINT
290 }
291
292 fn compatible(ty: &MssqlTypeInfo) -> bool {
293 matches!(
294 ty.kind(),
295 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
296 )
297 }
298}
299
300impl Encode<'_, Mssql> for i64 {
301 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
302 buf.extend_from_slice(&self.to_le_bytes());
303 Ok(IsNull::No)
304 }
305}
306
307impl Decode<'_, Mssql> for i64 {
308 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
309 decode_integer(value, "i64")
310 }
311}
312
313impl Type<Mssql> for f32 {
314 fn type_info() -> MssqlTypeInfo {
315 MssqlTypeInfo::REAL
316 }
317
318 fn compatible(ty: &MssqlTypeInfo) -> bool {
319 matches!(ty.kind(), MssqlType::Real)
320 }
321}
322
323impl Encode<'_, Mssql> for f32 {
324 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
325 buf.extend_from_slice(&self.to_le_bytes());
326 Ok(IsNull::No)
327 }
328}
329
330impl Decode<'_, Mssql> for f32 {
331 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
332 let bytes = value
333 .as_bytes()
334 .ok_or_else(|| "cannot decode SQL Server NULL as f32".to_owned())?;
335
336 match bytes {
337 [a, b, c, d] => Ok(f32::from_le_bytes([*a, *b, *c, *d])),
338 _ => Err("cannot decode SQL Server real as f32".into()),
339 }
340 }
341}
342
343impl Type<Mssql> for f64 {
344 fn type_info() -> MssqlTypeInfo {
345 MssqlTypeInfo::FLOAT
346 }
347
348 fn compatible(ty: &MssqlTypeInfo) -> bool {
349 matches!(ty.kind(), MssqlType::Real | MssqlType::Float)
350 }
351}
352
353impl Decode<'_, Mssql> for f64 {
354 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
355 match value
356 .as_bytes()
357 .ok_or_else(|| "cannot decode SQL Server NULL as f64".to_owned())?
358 {
359 [a, b, c, d] => Ok(f64::from(f32::from_le_bytes([*a, *b, *c, *d]))),
360 [a, b, c, d, e, f, g, h] => Ok(f64::from_le_bytes([*a, *b, *c, *d, *e, *f, *g, *h])),
361 _ => Err("cannot decode SQL Server float as f64".into()),
362 }
363 }
364}
365
366impl Encode<'_, Mssql> for f64 {
367 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
368 buf.extend_from_slice(&self.to_le_bytes());
369 Ok(IsNull::No)
370 }
371}
372
373impl Type<Mssql> for str {
374 fn type_info() -> MssqlTypeInfo {
375 MssqlTypeInfo::NVARCHAR
376 }
377
378 fn compatible(ty: &MssqlTypeInfo) -> bool {
379 matches!(ty.kind(), MssqlType::NVarChar | MssqlType::VarChar)
380 }
381}
382
383impl Type<Mssql> for String {
384 fn type_info() -> MssqlTypeInfo {
385 <str as Type<Mssql>>::type_info()
386 }
387
388 fn compatible(ty: &MssqlTypeInfo) -> bool {
389 <str as Type<Mssql>>::compatible(ty)
390 }
391}
392
393impl Encode<'_, Mssql> for str {
394 fn produces(&self) -> Option<MssqlTypeInfo> {
395 Some(MssqlTypeInfo::with_size(
396 MssqlType::NVarChar,
397 nvarchar_parameter_size(self),
398 ))
399 }
400
401 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
402 for unit in self.encode_utf16() {
403 buf.extend_from_slice(&unit.to_le_bytes());
404 }
405
406 Ok(IsNull::No)
407 }
408}
409
410impl<'q> Encode<'q, Mssql> for &'q str {
411 fn produces(&self) -> Option<MssqlTypeInfo> {
412 <str as Encode<Mssql>>::produces(*self)
413 }
414
415 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
416 <str as Encode<Mssql>>::encode_by_ref(*self, buf)
417 }
418}
419
420impl Encode<'_, Mssql> for String {
421 fn produces(&self) -> Option<MssqlTypeInfo> {
422 <str as Encode<Mssql>>::produces(self.as_str())
423 }
424
425 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
426 <str as Encode<Mssql>>::encode_by_ref(self.as_str(), buf)
427 }
428}
429
430impl Decode<'_, Mssql> for String {
431 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
432 let bytes = value
433 .as_bytes()
434 .ok_or_else(|| "cannot decode SQL Server NULL as String".to_owned())?;
435
436 if matches!(value.type_info.kind(), MssqlType::VarChar) {
437 return Ok(std::str::from_utf8(bytes)?.to_owned());
438 }
439
440 if bytes.len() % 2 != 0 {
441 return Err("cannot decode odd-length SQL Server UTF-16 text".into());
442 }
443
444 let units = bytes
445 .chunks_exact(2)
446 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
447 .collect::<Vec<_>>();
448 Ok(String::from_utf16(&units)?)
449 }
450}
451
452impl Type<Mssql> for [u8] {
453 fn type_info() -> MssqlTypeInfo {
454 MssqlTypeInfo::VARBINARY
455 }
456
457 fn compatible(ty: &MssqlTypeInfo) -> bool {
458 matches!(ty.kind(), MssqlType::VarBinary)
459 }
460}
461
462impl Type<Mssql> for Vec<u8> {
463 fn type_info() -> MssqlTypeInfo {
464 <[u8] as Type<Mssql>>::type_info()
465 }
466
467 fn compatible(ty: &MssqlTypeInfo) -> bool {
468 <[u8] as Type<Mssql>>::compatible(ty)
469 }
470}
471
472impl Encode<'_, Mssql> for [u8] {
473 fn produces(&self) -> Option<MssqlTypeInfo> {
474 Some(MssqlTypeInfo::with_size(
475 MssqlType::VarBinary,
476 varbinary_parameter_size(self.len()),
477 ))
478 }
479
480 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
481 buf.extend_from_slice(self);
482 Ok(IsNull::No)
483 }
484}
485
486impl<'q> Encode<'q, Mssql> for &'q [u8] {
487 fn produces(&self) -> Option<MssqlTypeInfo> {
488 <[u8] as Encode<Mssql>>::produces(*self)
489 }
490
491 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
492 <[u8] as Encode<Mssql>>::encode_by_ref(*self, buf)
493 }
494}
495
496impl Encode<'_, Mssql> for Vec<u8> {
497 fn produces(&self) -> Option<MssqlTypeInfo> {
498 <[u8] as Encode<Mssql>>::produces(self.as_slice())
499 }
500
501 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
502 <[u8] as Encode<Mssql>>::encode_by_ref(self.as_slice(), buf)
503 }
504}
505
506impl Decode<'_, Mssql> for Vec<u8> {
507 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
508 Ok(<&[u8] as Decode<Mssql>>::decode(value)?.to_vec())
509 }
510}
511
512impl<'r> Decode<'r, Mssql> for &'r [u8] {
513 fn decode(value: MssqlValueRef<'r>) -> Result<Self, BoxDynError> {
514 non_null_bytes(value, "bytes")
515 }
516}
517
/// Chooses the declared `NVARCHAR` size, in bytes, for a string parameter.
///
/// The size is the UTF-16 byte length (two bytes per code unit), with a floor
/// of 2. Lengths above 8000 bytes map to `u16::MAX`, which presumably signals
/// `NVARCHAR(MAX)` to the protocol layer — confirm against the TDS writer.
fn nvarchar_parameter_size(value: &str) -> u16 {
    let utf16_bytes = value.encode_utf16().count().saturating_mul(2);
    match utf16_bytes {
        0..=2 => 2,
        // Bounded by 8000 here, so the conversion cannot actually fail.
        3..=8000 => u16::try_from(utf16_bytes).unwrap_or(u16::MAX),
        _ => u16::MAX,
    }
}
526
/// Chooses the declared `VARBINARY` size, in bytes, for a byte parameter.
///
/// Empty input is declared with size 1. Lengths above 8000 map to `u16::MAX`,
/// which presumably signals `VARBINARY(MAX)` to the protocol layer — confirm
/// against the TDS writer.
fn varbinary_parameter_size(len: usize) -> u16 {
    match len {
        0 => 1,
        // Bounded by 8000 here, so the conversion cannot actually fail.
        1..=8000 => u16::try_from(len).unwrap_or(u16::MAX),
        _ => u16::MAX,
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn integer_scalars_use_lossless_parameter_types() {
        assert_eq!(MssqlTypeInfo::SMALLINT, <i8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::TINYINT, <u8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::INT, <u16 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::BIGINT, <u32 as Type<Mssql>>::type_info());
    }

    #[test]
    fn encodes_unsigned_integer_scalars_without_saturation() {
        let mut buf = Vec::new();
        let _ = <u32 as Encode<Mssql>>::encode_by_ref(&u32::MAX, &mut buf).unwrap();
        assert_eq!(i64::from(u32::MAX).to_le_bytes(), buf.as_slice());

        buf.clear();
        let _ = <u16 as Encode<Mssql>>::encode_by_ref(&u16::MAX, &mut buf).unwrap();
        assert_eq!(i32::from(u16::MAX).to_le_bytes(), buf.as_slice());
    }

    #[test]
    fn decodes_integer_scalars_with_range_checks() {
        let value = MssqlValue::new(MssqlTypeInfo::INT, Some(65_535_i32.to_le_bytes().to_vec()));
        assert_eq!(
            65_535_u16,
            <u16 as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );

        let negative = MssqlValue::new(MssqlTypeInfo::INT, Some((-1_i32).to_le_bytes().to_vec()));
        assert!(<u16 as Decode<Mssql>>::decode(negative.as_ref()).is_err());

        let too_large = MssqlValue::new(MssqlTypeInfo::INT, Some(128_i32.to_le_bytes().to_vec()));
        assert!(<i8 as Decode<Mssql>>::decode(too_large.as_ref()).is_err());
    }

    /// A 1-byte TINYINT payload is unsigned, so 200 fits `u8` but not `i8`.
    #[test]
    fn decodes_tinyint_as_unsigned() {
        let tinyint = MssqlValue::new(MssqlTypeInfo::TINYINT, Some(vec![200]));
        assert_eq!(200_u8, <u8 as Decode<Mssql>>::decode(tinyint.as_ref()).unwrap());
        assert_eq!(200_i64, <i64 as Decode<Mssql>>::decode(tinyint.as_ref()).unwrap());
        assert!(<i8 as Decode<Mssql>>::decode(tinyint.as_ref()).is_err());
    }

    /// NULL values are rejected with a message naming the requested Rust type.
    #[test]
    fn null_values_report_the_requested_rust_type() {
        let null = MssqlValue::null(MssqlTypeInfo::INT);
        assert!(null.is_null());

        let err = <i32 as Decode<Mssql>>::decode(null.as_ref()).unwrap_err();
        assert_eq!("cannot decode SQL Server NULL as i32", err.to_string());

        let err = <String as Decode<Mssql>>::decode(null.as_ref()).unwrap_err();
        assert_eq!("cannot decode SQL Server NULL as String", err.to_string());
    }

    #[test]
    fn decodes_bit_values_and_rejects_others() {
        let one = MssqlValue::new(MssqlTypeInfo::BIT, Some(vec![1]));
        assert!(<bool as Decode<Mssql>>::decode(one.as_ref()).unwrap());

        let zero = MssqlValue::new(MssqlTypeInfo::BIT, Some(vec![0]));
        assert!(!<bool as Decode<Mssql>>::decode(zero.as_ref()).unwrap());

        let two = MssqlValue::new(MssqlTypeInfo::BIT, Some(vec![2]));
        assert!(<bool as Decode<Mssql>>::decode(two.as_ref()).is_err());
    }

    #[test]
    fn decodes_real_and_float_as_f64() {
        let real = MssqlValue::new(MssqlTypeInfo::REAL, Some(1.5_f32.to_le_bytes().to_vec()));
        assert_eq!(1.5_f64, <f64 as Decode<Mssql>>::decode(real.as_ref()).unwrap());

        let float = MssqlValue::new(MssqlTypeInfo::FLOAT, Some(2.25_f64.to_le_bytes().to_vec()));
        assert_eq!(2.25_f64, <f64 as Decode<Mssql>>::decode(float.as_ref()).unwrap());
    }

    #[test]
    fn decodes_borrowed_bytes() {
        let value = MssqlValue::new(MssqlTypeInfo::VARBINARY, Some(vec![1, 2, 3, 4]));
        let bytes = <&[u8] as Decode<Mssql>>::decode(value.as_ref()).unwrap();

        assert_eq!(&[1, 2, 3, 4], bytes);
    }

    #[test]
    fn decodes_ascii_varchar_as_utf8() {
        let value = MssqlValue::new(MssqlTypeInfo::VARCHAR, Some(b"hello".to_vec()));

        assert_eq!(
            "hello",
            <String as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );
    }

    #[test]
    fn decodes_nvarchar_as_utf16() {
        let mut data = Vec::new();
        for unit in "hello".encode_utf16() {
            data.extend_from_slice(&unit.to_le_bytes());
        }

        let value = MssqlValue::new(MssqlTypeInfo::NVARCHAR, Some(data));

        assert_eq!(
            "hello",
            <String as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );
    }

    /// Text is written as UTF-16LE and declared with its UTF-16 byte length.
    #[test]
    fn encodes_text_as_utf16_le_nvarchar() {
        let text = "h\u{e9}";
        let mut buf = Vec::new();
        let _ = <str as Encode<Mssql>>::encode_by_ref(text, &mut buf).unwrap();
        assert_eq!(vec![b'h', 0, 0xE9, 0], buf);
        assert_eq!(
            Some(MssqlTypeInfo::with_size(MssqlType::NVarChar, 4)),
            <str as Encode<Mssql>>::produces(text)
        );
    }

    /// Declared sizes switch to `u16::MAX` just past 8000 bytes.
    #[test]
    fn parameter_sizes_switch_to_max_past_8000_bytes() {
        assert_eq!(2, nvarchar_parameter_size(""));
        assert_eq!(8000, nvarchar_parameter_size(&"x".repeat(4000)));
        assert_eq!(u16::MAX, nvarchar_parameter_size(&"x".repeat(4001)));

        assert_eq!(1, varbinary_parameter_size(0));
        assert_eq!(8000, varbinary_parameter_size(8000));
        assert_eq!(u16::MAX, varbinary_parameter_size(8001));
    }
}