1use std::borrow::Cow;
2
3use sqlx_core::decode::Decode;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7use sqlx_core::value::{Value, ValueRef};
8
9use crate::{Mssql, MssqlType, MssqlTypeInfo};
10
11#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct MssqlValue {
14 type_info: MssqlTypeInfo,
15 data: Option<Vec<u8>>,
16}
17
18impl MssqlValue {
19 pub(crate) fn new(type_info: MssqlTypeInfo, data: Option<Vec<u8>>) -> Self {
21 Self { type_info, data }
22 }
23
24 pub fn null(type_info: MssqlTypeInfo) -> Self {
26 Self {
27 type_info,
28 data: None,
29 }
30 }
31}
32
33impl Value for MssqlValue {
34 type Database = Mssql;
35
36 fn as_ref(&self) -> MssqlValueRef<'_> {
37 MssqlValueRef {
38 type_info: &self.type_info,
39 data: self.data.as_deref(),
40 }
41 }
42
43 fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
44 Cow::Borrowed(&self.type_info)
45 }
46
47 fn is_null(&self) -> bool {
48 self.data.is_none()
49 }
50}
51
52#[derive(Debug, Clone, Copy)]
54pub struct MssqlValueRef<'r> {
55 type_info: &'r MssqlTypeInfo,
56 data: Option<&'r [u8]>,
57}
58
59impl<'r> ValueRef<'r> for MssqlValueRef<'r> {
60 type Database = Mssql;
61
62 fn to_owned(&self) -> MssqlValue {
63 MssqlValue {
64 type_info: self.type_info.clone(),
65 data: self.data.map(ToOwned::to_owned),
66 }
67 }
68
69 fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
70 Cow::Borrowed(self.type_info)
71 }
72
73 fn is_null(&self) -> bool {
74 self.data.is_none()
75 }
76}
77
78impl<'r> MssqlValueRef<'r> {
79 pub(crate) fn as_bytes(&self) -> Option<&'r [u8]> {
80 self.data
81 }
82}
83
84fn non_null_bytes<'r>(value: MssqlValueRef<'r>, rust_type: &str) -> Result<&'r [u8], BoxDynError> {
85 value
86 .as_bytes()
87 .ok_or_else(|| format!("cannot decode SQL Server NULL as {rust_type}").into())
88}
89
90fn decode_integer(value: MssqlValueRef<'_>, rust_type: &str) -> Result<i64, BoxDynError> {
91 let bytes = non_null_bytes(value, rust_type)?;
92
93 match bytes.len() {
94 1 => Ok(i64::from(bytes[0])),
95 2 => Ok(i64::from(i16::from_le_bytes([bytes[0], bytes[1]]))),
96 4 => Ok(i64::from(i32::from_le_bytes([
97 bytes[0], bytes[1], bytes[2], bytes[3],
98 ]))),
99 8 => Ok(i64::from_le_bytes([
100 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
101 ])),
102 len => Err(format!("cannot decode {len}-byte SQL Server integer as {rust_type}").into()),
103 }
104}
105
106impl Type<Mssql> for i8 {
107 fn type_info() -> MssqlTypeInfo {
108 MssqlTypeInfo::SMALLINT
109 }
110
111 fn compatible(ty: &MssqlTypeInfo) -> bool {
112 matches!(
113 ty.kind(),
114 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
115 )
116 }
117}
118
119impl Encode<'_, Mssql> for i8 {
120 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
121 <i16 as Encode<Mssql>>::encode_by_ref(&i16::from(*self), buf)
122 }
123}
124
125impl Decode<'_, Mssql> for i8 {
126 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
127 Ok(i8::try_from(decode_integer(value, "i8")?)?)
128 }
129}
130
131impl Type<Mssql> for u8 {
132 fn type_info() -> MssqlTypeInfo {
133 MssqlTypeInfo::TINYINT
134 }
135
136 fn compatible(ty: &MssqlTypeInfo) -> bool {
137 matches!(
138 ty.kind(),
139 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
140 )
141 }
142}
143
144impl Encode<'_, Mssql> for u8 {
145 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
146 buf.push(*self);
147 Ok(IsNull::No)
148 }
149}
150
151impl Decode<'_, Mssql> for u8 {
152 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
153 Ok(u8::try_from(decode_integer(value, "u8")?)?)
154 }
155}
156
157impl Type<Mssql> for i32 {
158 fn type_info() -> MssqlTypeInfo {
159 MssqlTypeInfo::INT
160 }
161
162 fn compatible(ty: &MssqlTypeInfo) -> bool {
163 matches!(
164 ty.kind(),
165 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int
166 )
167 }
168}
169
170impl Encode<'_, Mssql> for i32 {
171 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
172 buf.extend_from_slice(&self.to_le_bytes());
173 Ok(IsNull::No)
174 }
175}
176
177impl Type<Mssql> for bool {
178 fn type_info() -> MssqlTypeInfo {
179 MssqlTypeInfo::BIT
180 }
181
182 fn compatible(ty: &MssqlTypeInfo) -> bool {
183 matches!(ty.kind(), MssqlType::Bit)
184 }
185}
186
187impl Encode<'_, Mssql> for bool {
188 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
189 buf.push(u8::from(*self));
190 Ok(IsNull::No)
191 }
192}
193
194impl Decode<'_, Mssql> for bool {
195 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
196 let bytes = value
197 .as_bytes()
198 .ok_or_else(|| "cannot decode SQL Server NULL as bool".to_owned())?;
199
200 match bytes {
201 [0] => Ok(false),
202 [1] => Ok(true),
203 _ => Err("cannot decode SQL Server bit as bool".into()),
204 }
205 }
206}
207
208impl Type<Mssql> for i16 {
209 fn type_info() -> MssqlTypeInfo {
210 MssqlTypeInfo::SMALLINT
211 }
212
213 fn compatible(ty: &MssqlTypeInfo) -> bool {
214 matches!(ty.kind(), MssqlType::TinyInt | MssqlType::SmallInt)
215 }
216}
217
218impl Encode<'_, Mssql> for i16 {
219 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
220 buf.extend_from_slice(&self.to_le_bytes());
221 Ok(IsNull::No)
222 }
223}
224
225impl Decode<'_, Mssql> for i16 {
226 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
227 Ok(i16::try_from(decode_integer(value, "i16")?)?)
228 }
229}
230
231impl Type<Mssql> for u16 {
232 fn type_info() -> MssqlTypeInfo {
233 MssqlTypeInfo::INT
234 }
235
236 fn compatible(ty: &MssqlTypeInfo) -> bool {
237 matches!(
238 ty.kind(),
239 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
240 )
241 }
242}
243
244impl Encode<'_, Mssql> for u16 {
245 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
246 <i32 as Encode<Mssql>>::encode_by_ref(&i32::from(*self), buf)
247 }
248}
249
250impl Decode<'_, Mssql> for u16 {
251 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
252 Ok(u16::try_from(decode_integer(value, "u16")?)?)
253 }
254}
255
256impl Decode<'_, Mssql> for i32 {
257 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
258 Ok(i32::try_from(decode_integer(value, "i32")?)?)
259 }
260}
261
262impl Type<Mssql> for u32 {
263 fn type_info() -> MssqlTypeInfo {
264 MssqlTypeInfo::BIGINT
265 }
266
267 fn compatible(ty: &MssqlTypeInfo) -> bool {
268 matches!(
269 ty.kind(),
270 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
271 )
272 }
273}
274
275impl Encode<'_, Mssql> for u32 {
276 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
277 <i64 as Encode<Mssql>>::encode_by_ref(&i64::from(*self), buf)
278 }
279}
280
281impl Decode<'_, Mssql> for u32 {
282 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
283 Ok(u32::try_from(decode_integer(value, "u32")?)?)
284 }
285}
286
287impl Type<Mssql> for i64 {
288 fn type_info() -> MssqlTypeInfo {
289 MssqlTypeInfo::BIGINT
290 }
291
292 fn compatible(ty: &MssqlTypeInfo) -> bool {
293 matches!(
294 ty.kind(),
295 MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
296 )
297 }
298}
299
300impl Encode<'_, Mssql> for i64 {
301 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
302 buf.extend_from_slice(&self.to_le_bytes());
303 Ok(IsNull::No)
304 }
305}
306
307impl Decode<'_, Mssql> for i64 {
308 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
309 decode_integer(value, "i64")
310 }
311}
312
313impl Type<Mssql> for f32 {
314 fn type_info() -> MssqlTypeInfo {
315 MssqlTypeInfo::REAL
316 }
317
318 fn compatible(ty: &MssqlTypeInfo) -> bool {
319 matches!(ty.kind(), MssqlType::Real)
320 }
321}
322
323impl Encode<'_, Mssql> for f32 {
324 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
325 buf.extend_from_slice(&self.to_le_bytes());
326 Ok(IsNull::No)
327 }
328}
329
330impl Decode<'_, Mssql> for f32 {
331 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
332 let bytes = value
333 .as_bytes()
334 .ok_or_else(|| "cannot decode SQL Server NULL as f32".to_owned())?;
335
336 match bytes {
337 [a, b, c, d] => Ok(f32::from_le_bytes([*a, *b, *c, *d])),
338 _ => Err("cannot decode SQL Server real as f32".into()),
339 }
340 }
341}
342
343impl Type<Mssql> for f64 {
344 fn type_info() -> MssqlTypeInfo {
345 MssqlTypeInfo::FLOAT
346 }
347
348 fn compatible(ty: &MssqlTypeInfo) -> bool {
349 matches!(ty.kind(), MssqlType::Real | MssqlType::Float)
350 }
351}
352
353impl Decode<'_, Mssql> for f64 {
354 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
355 match value
356 .as_bytes()
357 .ok_or_else(|| "cannot decode SQL Server NULL as f64".to_owned())?
358 {
359 [a, b, c, d] => Ok(f64::from(f32::from_le_bytes([*a, *b, *c, *d]))),
360 [a, b, c, d, e, f, g, h] => Ok(f64::from_le_bytes([*a, *b, *c, *d, *e, *f, *g, *h])),
361 _ => Err("cannot decode SQL Server float as f64".into()),
362 }
363 }
364}
365
366impl Encode<'_, Mssql> for f64 {
367 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
368 buf.extend_from_slice(&self.to_le_bytes());
369 Ok(IsNull::No)
370 }
371}
372
373impl Type<Mssql> for str {
374 fn type_info() -> MssqlTypeInfo {
375 MssqlTypeInfo::NVARCHAR
376 }
377
378 fn compatible(ty: &MssqlTypeInfo) -> bool {
379 matches!(ty.kind(), MssqlType::NVarChar)
380 }
381}
382
383impl Type<Mssql> for String {
384 fn type_info() -> MssqlTypeInfo {
385 <str as Type<Mssql>>::type_info()
386 }
387
388 fn compatible(ty: &MssqlTypeInfo) -> bool {
389 <str as Type<Mssql>>::compatible(ty)
390 }
391}
392
393impl Encode<'_, Mssql> for str {
394 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
395 for unit in self.encode_utf16() {
396 buf.extend_from_slice(&unit.to_le_bytes());
397 }
398
399 Ok(IsNull::No)
400 }
401}
402
403impl<'q> Encode<'q, Mssql> for &'q str {
404 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
405 <str as Encode<Mssql>>::encode_by_ref(*self, buf)
406 }
407}
408
409impl Encode<'_, Mssql> for String {
410 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
411 <str as Encode<Mssql>>::encode_by_ref(self.as_str(), buf)
412 }
413}
414
415impl Decode<'_, Mssql> for String {
416 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
417 let bytes = value
418 .as_bytes()
419 .ok_or_else(|| "cannot decode SQL Server NULL as String".to_owned())?;
420
421 if bytes.len() % 2 != 0 {
422 return Err("cannot decode odd-length SQL Server UTF-16 text".into());
423 }
424
425 let units = bytes
426 .chunks_exact(2)
427 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
428 .collect::<Vec<_>>();
429 Ok(String::from_utf16(&units)?)
430 }
431}
432
433impl Type<Mssql> for [u8] {
434 fn type_info() -> MssqlTypeInfo {
435 MssqlTypeInfo::VARBINARY
436 }
437
438 fn compatible(ty: &MssqlTypeInfo) -> bool {
439 matches!(ty.kind(), MssqlType::VarBinary)
440 }
441}
442
443impl Type<Mssql> for Vec<u8> {
444 fn type_info() -> MssqlTypeInfo {
445 <[u8] as Type<Mssql>>::type_info()
446 }
447
448 fn compatible(ty: &MssqlTypeInfo) -> bool {
449 <[u8] as Type<Mssql>>::compatible(ty)
450 }
451}
452
453impl Encode<'_, Mssql> for [u8] {
454 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
455 buf.extend_from_slice(self);
456 Ok(IsNull::No)
457 }
458}
459
460impl<'q> Encode<'q, Mssql> for &'q [u8] {
461 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
462 <[u8] as Encode<Mssql>>::encode_by_ref(*self, buf)
463 }
464}
465
466impl Encode<'_, Mssql> for Vec<u8> {
467 fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
468 <[u8] as Encode<Mssql>>::encode_by_ref(self.as_slice(), buf)
469 }
470}
471
472impl Decode<'_, Mssql> for Vec<u8> {
473 fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
474 Ok(<&[u8] as Decode<Mssql>>::decode(value)?.to_vec())
475 }
476}
477
478impl<'r> Decode<'r, Mssql> for &'r [u8] {
479 fn decode(value: MssqlValueRef<'r>) -> Result<Self, BoxDynError> {
480 non_null_bytes(value, "bytes")
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn integer_scalars_use_lossless_parameter_types() {
        assert_eq!(MssqlTypeInfo::SMALLINT, <i8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::TINYINT, <u8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::INT, <u16 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::BIGINT, <u32 as Type<Mssql>>::type_info());
    }

    #[test]
    fn encodes_unsigned_integer_scalars_without_saturation() {
        let mut buf = Vec::new();
        let _ = <u32 as Encode<Mssql>>::encode_by_ref(&u32::MAX, &mut buf).unwrap();
        assert_eq!(i64::from(u32::MAX).to_le_bytes(), buf.as_slice());

        buf.clear();
        let _ = <u16 as Encode<Mssql>>::encode_by_ref(&u16::MAX, &mut buf).unwrap();
        assert_eq!(i32::from(u16::MAX).to_le_bytes(), buf.as_slice());
    }

    #[test]
    fn decodes_integer_scalars_with_range_checks() {
        let value = MssqlValue::new(MssqlTypeInfo::INT, Some(65_535_i32.to_le_bytes().to_vec()));
        assert_eq!(
            65_535_u16,
            <u16 as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );

        let negative = MssqlValue::new(MssqlTypeInfo::INT, Some((-1_i32).to_le_bytes().to_vec()));
        assert!(<u16 as Decode<Mssql>>::decode(negative.as_ref()).is_err());

        let too_large = MssqlValue::new(MssqlTypeInfo::INT, Some(128_i32.to_le_bytes().to_vec()));
        assert!(<i8 as Decode<Mssql>>::decode(too_large.as_ref()).is_err());
    }

    #[test]
    fn decodes_unsigned_tinyint() {
        // A single byte is TINYINT and must be read as unsigned (200, not -56).
        let value = MssqlValue::new(MssqlTypeInfo::TINYINT, Some(vec![200]));
        assert_eq!(200_u8, <u8 as Decode<Mssql>>::decode(value.as_ref()).unwrap());
        assert_eq!(200_i16, <i16 as Decode<Mssql>>::decode(value.as_ref()).unwrap());
        assert!(<i8 as Decode<Mssql>>::decode(value.as_ref()).is_err());
    }

    #[test]
    fn null_values_fail_to_decode_as_scalars() {
        let null = MssqlValue::null(MssqlTypeInfo::INT);
        assert!(null.is_null());
        assert!(null.as_ref().is_null());

        let err = <i32 as Decode<Mssql>>::decode(null.as_ref()).unwrap_err();
        assert_eq!("cannot decode SQL Server NULL as i32", err.to_string());
    }

    #[test]
    fn owned_and_borrowed_values_round_trip() {
        let value = MssqlValue::new(MssqlTypeInfo::VARBINARY, Some(vec![9, 8, 7]));
        let owned = ValueRef::to_owned(&value.as_ref());
        assert_eq!(value, owned);
    }

    #[test]
    fn decodes_bits_and_rejects_other_bytes() {
        let decode = |bytes: Vec<u8>| {
            let value = MssqlValue::new(MssqlTypeInfo::BIT, Some(bytes));
            <bool as Decode<Mssql>>::decode(value.as_ref())
        };

        assert!(!decode(vec![0]).unwrap());
        assert!(decode(vec![1]).unwrap());
        assert!(decode(vec![2]).is_err());
        assert!(decode(vec![1, 0]).is_err());
    }

    #[test]
    fn decodes_real_as_f64() {
        let value = MssqlValue::new(MssqlTypeInfo::REAL, Some(1.5_f32.to_le_bytes().to_vec()));
        assert_eq!(1.5_f64, <f64 as Decode<Mssql>>::decode(value.as_ref()).unwrap());
    }

    #[test]
    fn round_trips_utf16_text() {
        // Includes a non-BMP character to exercise surrogate pairs.
        let text = "héllo, wörld 🦀";
        let mut buf = Vec::new();
        let _ = <str as Encode<Mssql>>::encode_by_ref(text, &mut buf).unwrap();

        let value = MssqlValue::new(MssqlTypeInfo::NVARCHAR, Some(buf));
        assert_eq!(text, <String as Decode<Mssql>>::decode(value.as_ref()).unwrap());
    }

    #[test]
    fn rejects_odd_length_utf16_text() {
        let value = MssqlValue::new(MssqlTypeInfo::NVARCHAR, Some(vec![0x41]));
        assert!(<String as Decode<Mssql>>::decode(value.as_ref()).is_err());
    }

    #[test]
    fn decodes_borrowed_bytes() {
        let value = MssqlValue::new(MssqlTypeInfo::VARBINARY, Some(vec![1, 2, 3, 4]));
        let bytes = <&[u8] as Decode<Mssql>>::decode(value.as_ref()).unwrap();

        assert_eq!(&[1, 2, 3, 4], bytes);
    }
}