Skip to main content

sqlx_sqlserver/
value.rs

1use std::borrow::Cow;
2
3use sqlx_core::decode::Decode;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7use sqlx_core::value::{Value, ValueRef};
8
9use crate::{Mssql, MssqlType, MssqlTypeInfo};
10
11/// Owned SQL Server value skeleton.
12#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct MssqlValue {
14    type_info: MssqlTypeInfo,
15    data: Option<Vec<u8>>,
16}
17
18impl MssqlValue {
19    /// Creates an owned value from type information and raw little-endian TDS bytes.
20    pub(crate) fn new(type_info: MssqlTypeInfo, data: Option<Vec<u8>>) -> Self {
21        Self { type_info, data }
22    }
23
24    /// Creates a `NULL` value with the supplied type information.
25    pub fn null(type_info: MssqlTypeInfo) -> Self {
26        Self {
27            type_info,
28            data: None,
29        }
30    }
31}
32
33impl Value for MssqlValue {
34    type Database = Mssql;
35
36    fn as_ref(&self) -> MssqlValueRef<'_> {
37        MssqlValueRef {
38            type_info: &self.type_info,
39            data: self.data.as_deref(),
40        }
41    }
42
43    fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
44        Cow::Borrowed(&self.type_info)
45    }
46
47    fn is_null(&self) -> bool {
48        self.data.is_none()
49    }
50}
51
52/// Borrowed SQL Server value skeleton.
53#[derive(Debug, Clone, Copy)]
54pub struct MssqlValueRef<'r> {
55    type_info: &'r MssqlTypeInfo,
56    data: Option<&'r [u8]>,
57}
58
59impl<'r> ValueRef<'r> for MssqlValueRef<'r> {
60    type Database = Mssql;
61
62    fn to_owned(&self) -> MssqlValue {
63        MssqlValue {
64            type_info: self.type_info.clone(),
65            data: self.data.map(ToOwned::to_owned),
66        }
67    }
68
69    fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
70        Cow::Borrowed(self.type_info)
71    }
72
73    fn is_null(&self) -> bool {
74        self.data.is_none()
75    }
76}
77
78impl<'r> MssqlValueRef<'r> {
79    pub(crate) fn as_bytes(&self) -> Option<&'r [u8]> {
80        self.data
81    }
82}
83
84fn non_null_bytes<'r>(value: MssqlValueRef<'r>, rust_type: &str) -> Result<&'r [u8], BoxDynError> {
85    value
86        .as_bytes()
87        .ok_or_else(|| format!("cannot decode SQL Server NULL as {rust_type}").into())
88}
89
90fn decode_integer(value: MssqlValueRef<'_>, rust_type: &str) -> Result<i64, BoxDynError> {
91    let bytes = non_null_bytes(value, rust_type)?;
92
93    match bytes.len() {
94        1 => Ok(i64::from(bytes[0])),
95        2 => Ok(i64::from(i16::from_le_bytes([bytes[0], bytes[1]]))),
96        4 => Ok(i64::from(i32::from_le_bytes([
97            bytes[0], bytes[1], bytes[2], bytes[3],
98        ]))),
99        8 => Ok(i64::from_le_bytes([
100            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
101        ])),
102        len => Err(format!("cannot decode {len}-byte SQL Server integer as {rust_type}").into()),
103    }
104}
105
106impl Type<Mssql> for i8 {
107    fn type_info() -> MssqlTypeInfo {
108        MssqlTypeInfo::SMALLINT
109    }
110
111    fn compatible(ty: &MssqlTypeInfo) -> bool {
112        matches!(
113            ty.kind(),
114            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
115        )
116    }
117}
118
119impl Encode<'_, Mssql> for i8 {
120    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
121        <i16 as Encode<Mssql>>::encode_by_ref(&i16::from(*self), buf)
122    }
123}
124
125impl Decode<'_, Mssql> for i8 {
126    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
127        Ok(i8::try_from(decode_integer(value, "i8")?)?)
128    }
129}
130
131impl Type<Mssql> for u8 {
132    fn type_info() -> MssqlTypeInfo {
133        MssqlTypeInfo::TINYINT
134    }
135
136    fn compatible(ty: &MssqlTypeInfo) -> bool {
137        matches!(
138            ty.kind(),
139            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
140        )
141    }
142}
143
144impl Encode<'_, Mssql> for u8 {
145    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
146        buf.push(*self);
147        Ok(IsNull::No)
148    }
149}
150
151impl Decode<'_, Mssql> for u8 {
152    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
153        Ok(u8::try_from(decode_integer(value, "u8")?)?)
154    }
155}
156
157impl Type<Mssql> for i32 {
158    fn type_info() -> MssqlTypeInfo {
159        MssqlTypeInfo::INT
160    }
161
162    fn compatible(ty: &MssqlTypeInfo) -> bool {
163        matches!(
164            ty.kind(),
165            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int
166        )
167    }
168}
169
170impl Encode<'_, Mssql> for i32 {
171    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
172        buf.extend_from_slice(&self.to_le_bytes());
173        Ok(IsNull::No)
174    }
175}
176
177impl Type<Mssql> for bool {
178    fn type_info() -> MssqlTypeInfo {
179        MssqlTypeInfo::BIT
180    }
181
182    fn compatible(ty: &MssqlTypeInfo) -> bool {
183        matches!(ty.kind(), MssqlType::Bit)
184    }
185}
186
187impl Encode<'_, Mssql> for bool {
188    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
189        buf.push(u8::from(*self));
190        Ok(IsNull::No)
191    }
192}
193
194impl Decode<'_, Mssql> for bool {
195    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
196        let bytes = value
197            .as_bytes()
198            .ok_or_else(|| "cannot decode SQL Server NULL as bool".to_owned())?;
199
200        match bytes {
201            [0] => Ok(false),
202            [1] => Ok(true),
203            _ => Err("cannot decode SQL Server bit as bool".into()),
204        }
205    }
206}
207
208impl Type<Mssql> for i16 {
209    fn type_info() -> MssqlTypeInfo {
210        MssqlTypeInfo::SMALLINT
211    }
212
213    fn compatible(ty: &MssqlTypeInfo) -> bool {
214        matches!(ty.kind(), MssqlType::TinyInt | MssqlType::SmallInt)
215    }
216}
217
218impl Encode<'_, Mssql> for i16 {
219    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
220        buf.extend_from_slice(&self.to_le_bytes());
221        Ok(IsNull::No)
222    }
223}
224
225impl Decode<'_, Mssql> for i16 {
226    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
227        Ok(i16::try_from(decode_integer(value, "i16")?)?)
228    }
229}
230
231impl Type<Mssql> for u16 {
232    fn type_info() -> MssqlTypeInfo {
233        MssqlTypeInfo::INT
234    }
235
236    fn compatible(ty: &MssqlTypeInfo) -> bool {
237        matches!(
238            ty.kind(),
239            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
240        )
241    }
242}
243
244impl Encode<'_, Mssql> for u16 {
245    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
246        <i32 as Encode<Mssql>>::encode_by_ref(&i32::from(*self), buf)
247    }
248}
249
250impl Decode<'_, Mssql> for u16 {
251    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
252        Ok(u16::try_from(decode_integer(value, "u16")?)?)
253    }
254}
255
256impl Decode<'_, Mssql> for i32 {
257    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
258        Ok(i32::try_from(decode_integer(value, "i32")?)?)
259    }
260}
261
262impl Type<Mssql> for u32 {
263    fn type_info() -> MssqlTypeInfo {
264        MssqlTypeInfo::BIGINT
265    }
266
267    fn compatible(ty: &MssqlTypeInfo) -> bool {
268        matches!(
269            ty.kind(),
270            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
271        )
272    }
273}
274
275impl Encode<'_, Mssql> for u32 {
276    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
277        <i64 as Encode<Mssql>>::encode_by_ref(&i64::from(*self), buf)
278    }
279}
280
281impl Decode<'_, Mssql> for u32 {
282    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
283        Ok(u32::try_from(decode_integer(value, "u32")?)?)
284    }
285}
286
287impl Type<Mssql> for i64 {
288    fn type_info() -> MssqlTypeInfo {
289        MssqlTypeInfo::BIGINT
290    }
291
292    fn compatible(ty: &MssqlTypeInfo) -> bool {
293        matches!(
294            ty.kind(),
295            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
296        )
297    }
298}
299
300impl Encode<'_, Mssql> for i64 {
301    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
302        buf.extend_from_slice(&self.to_le_bytes());
303        Ok(IsNull::No)
304    }
305}
306
307impl Decode<'_, Mssql> for i64 {
308    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
309        decode_integer(value, "i64")
310    }
311}
312
313impl Type<Mssql> for f32 {
314    fn type_info() -> MssqlTypeInfo {
315        MssqlTypeInfo::REAL
316    }
317
318    fn compatible(ty: &MssqlTypeInfo) -> bool {
319        matches!(ty.kind(), MssqlType::Real)
320    }
321}
322
323impl Encode<'_, Mssql> for f32 {
324    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
325        buf.extend_from_slice(&self.to_le_bytes());
326        Ok(IsNull::No)
327    }
328}
329
330impl Decode<'_, Mssql> for f32 {
331    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
332        let bytes = value
333            .as_bytes()
334            .ok_or_else(|| "cannot decode SQL Server NULL as f32".to_owned())?;
335
336        match bytes {
337            [a, b, c, d] => Ok(f32::from_le_bytes([*a, *b, *c, *d])),
338            _ => Err("cannot decode SQL Server real as f32".into()),
339        }
340    }
341}
342
343impl Type<Mssql> for f64 {
344    fn type_info() -> MssqlTypeInfo {
345        MssqlTypeInfo::FLOAT
346    }
347
348    fn compatible(ty: &MssqlTypeInfo) -> bool {
349        matches!(ty.kind(), MssqlType::Real | MssqlType::Float)
350    }
351}
352
353impl Decode<'_, Mssql> for f64 {
354    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
355        match value
356            .as_bytes()
357            .ok_or_else(|| "cannot decode SQL Server NULL as f64".to_owned())?
358        {
359            [a, b, c, d] => Ok(f64::from(f32::from_le_bytes([*a, *b, *c, *d]))),
360            [a, b, c, d, e, f, g, h] => Ok(f64::from_le_bytes([*a, *b, *c, *d, *e, *f, *g, *h])),
361            _ => Err("cannot decode SQL Server float as f64".into()),
362        }
363    }
364}
365
366impl Encode<'_, Mssql> for f64 {
367    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
368        buf.extend_from_slice(&self.to_le_bytes());
369        Ok(IsNull::No)
370    }
371}
372
373impl Type<Mssql> for str {
374    fn type_info() -> MssqlTypeInfo {
375        MssqlTypeInfo::NVARCHAR
376    }
377
378    fn compatible(ty: &MssqlTypeInfo) -> bool {
379        matches!(ty.kind(), MssqlType::NVarChar | MssqlType::VarChar)
380    }
381}
382
383impl Type<Mssql> for String {
384    fn type_info() -> MssqlTypeInfo {
385        <str as Type<Mssql>>::type_info()
386    }
387
388    fn compatible(ty: &MssqlTypeInfo) -> bool {
389        <str as Type<Mssql>>::compatible(ty)
390    }
391}
392
393impl Encode<'_, Mssql> for str {
394    fn produces(&self) -> Option<MssqlTypeInfo> {
395        Some(MssqlTypeInfo::with_size(
396            MssqlType::NVarChar,
397            nvarchar_parameter_size(self),
398        ))
399    }
400
401    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
402        for unit in self.encode_utf16() {
403            buf.extend_from_slice(&unit.to_le_bytes());
404        }
405
406        Ok(IsNull::No)
407    }
408}
409
410impl<'q> Encode<'q, Mssql> for &'q str {
411    fn produces(&self) -> Option<MssqlTypeInfo> {
412        <str as Encode<Mssql>>::produces(*self)
413    }
414
415    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
416        <str as Encode<Mssql>>::encode_by_ref(*self, buf)
417    }
418}
419
420impl Encode<'_, Mssql> for String {
421    fn produces(&self) -> Option<MssqlTypeInfo> {
422        <str as Encode<Mssql>>::produces(self.as_str())
423    }
424
425    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
426        <str as Encode<Mssql>>::encode_by_ref(self.as_str(), buf)
427    }
428}
429
430impl Decode<'_, Mssql> for String {
431    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
432        let bytes = value
433            .as_bytes()
434            .ok_or_else(|| "cannot decode SQL Server NULL as String".to_owned())?;
435
436        if matches!(value.type_info.kind(), MssqlType::VarChar) {
437            return Ok(std::str::from_utf8(bytes)?.to_owned());
438        }
439
440        if bytes.len() % 2 != 0 {
441            return Err("cannot decode odd-length SQL Server UTF-16 text".into());
442        }
443
444        let units = bytes
445            .chunks_exact(2)
446            .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
447            .collect::<Vec<_>>();
448        Ok(String::from_utf16(&units)?)
449    }
450}
451
452impl Type<Mssql> for [u8] {
453    fn type_info() -> MssqlTypeInfo {
454        MssqlTypeInfo::VARBINARY
455    }
456
457    fn compatible(ty: &MssqlTypeInfo) -> bool {
458        matches!(ty.kind(), MssqlType::VarBinary)
459    }
460}
461
462impl Type<Mssql> for Vec<u8> {
463    fn type_info() -> MssqlTypeInfo {
464        <[u8] as Type<Mssql>>::type_info()
465    }
466
467    fn compatible(ty: &MssqlTypeInfo) -> bool {
468        <[u8] as Type<Mssql>>::compatible(ty)
469    }
470}
471
472impl Encode<'_, Mssql> for [u8] {
473    fn produces(&self) -> Option<MssqlTypeInfo> {
474        Some(MssqlTypeInfo::with_size(
475            MssqlType::VarBinary,
476            varbinary_parameter_size(self.len()),
477        ))
478    }
479
480    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
481        buf.extend_from_slice(self);
482        Ok(IsNull::No)
483    }
484}
485
486impl<'q> Encode<'q, Mssql> for &'q [u8] {
487    fn produces(&self) -> Option<MssqlTypeInfo> {
488        <[u8] as Encode<Mssql>>::produces(*self)
489    }
490
491    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
492        <[u8] as Encode<Mssql>>::encode_by_ref(*self, buf)
493    }
494}
495
496impl Encode<'_, Mssql> for Vec<u8> {
497    fn produces(&self) -> Option<MssqlTypeInfo> {
498        <[u8] as Encode<Mssql>>::produces(self.as_slice())
499    }
500
501    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
502        <[u8] as Encode<Mssql>>::encode_by_ref(self.as_slice(), buf)
503    }
504}
505
506impl Decode<'_, Mssql> for Vec<u8> {
507    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
508        Ok(<&[u8] as Decode<Mssql>>::decode(value)?.to_vec())
509    }
510}
511
512impl<'r> Decode<'r, Mssql> for &'r [u8] {
513    fn decode(value: MssqlValueRef<'r>) -> Result<Self, BoxDynError> {
514        non_null_bytes(value, "bytes")
515    }
516}
517
/// Computes the declared byte size for an `NVARCHAR` parameter holding `value`.
///
/// The size is the UTF-16 byte length (two bytes per code unit), with a floor
/// of 2. Lengths above 8000 bytes yield `u16::MAX` (presumably the `NVARCHAR(MAX)`
/// marker — confirm against the TDS writer).
fn nvarchar_parameter_size(value: &str) -> u16 {
    let utf16_bytes = value.encode_utf16().count().saturating_mul(2);
    match u16::try_from(utf16_bytes.max(2)) {
        Ok(size) if utf16_bytes <= 8000 => size,
        _ => u16::MAX,
    }
}
526
/// Computes the declared byte size for a `VARBINARY` parameter of `len` bytes.
///
/// Empty payloads are declared as 1 byte. Lengths above 8000 yield `u16::MAX`
/// (presumably the `VARBINARY(MAX)` marker — confirm against the TDS writer).
fn varbinary_parameter_size(len: usize) -> u16 {
    match len {
        0 => 1,
        1..=8000 => u16::try_from(len).unwrap_or(u16::MAX),
        _ => u16::MAX,
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn integer_scalars_use_lossless_parameter_types() {
        assert_eq!(MssqlTypeInfo::SMALLINT, <i8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::TINYINT, <u8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::INT, <u16 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::BIGINT, <u32 as Type<Mssql>>::type_info());
    }

    #[test]
    fn encodes_unsigned_integer_scalars_without_saturation() {
        let mut buf = Vec::new();
        let _ = <u32 as Encode<Mssql>>::encode_by_ref(&u32::MAX, &mut buf).unwrap();
        assert_eq!(i64::from(u32::MAX).to_le_bytes(), buf.as_slice());

        buf.clear();
        let _ = <u16 as Encode<Mssql>>::encode_by_ref(&u16::MAX, &mut buf).unwrap();
        assert_eq!(i32::from(u16::MAX).to_le_bytes(), buf.as_slice());
    }

    #[test]
    fn decodes_integer_scalars_with_range_checks() {
        let value = MssqlValue::new(MssqlTypeInfo::INT, Some(65_535_i32.to_le_bytes().to_vec()));
        assert_eq!(
            65_535_u16,
            <u16 as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );

        let negative = MssqlValue::new(MssqlTypeInfo::INT, Some((-1_i32).to_le_bytes().to_vec()));
        assert!(<u16 as Decode<Mssql>>::decode(negative.as_ref()).is_err());

        let too_large = MssqlValue::new(MssqlTypeInfo::INT, Some(128_i32.to_le_bytes().to_vec()));
        assert!(<i8 as Decode<Mssql>>::decode(too_large.as_ref()).is_err());
    }

    /// A one-byte integer payload is read as unsigned, so 0xFF is 255, not -1.
    #[test]
    fn tinyint_payloads_are_unsigned() {
        let value = MssqlValue::new(MssqlTypeInfo::TINYINT, Some(vec![0xFF]));
        assert_eq!(255_u8, <u8 as Decode<Mssql>>::decode(value.as_ref()).unwrap());
        assert_eq!(255_i64, <i64 as Decode<Mssql>>::decode(value.as_ref()).unwrap());
        assert!(<i8 as Decode<Mssql>>::decode(value.as_ref()).is_err());
    }

    #[test]
    fn rejects_unexpected_integer_widths() {
        let value = MssqlValue::new(MssqlTypeInfo::INT, Some(vec![1, 2, 3]));
        let err = <i64 as Decode<Mssql>>::decode(value.as_ref()).unwrap_err();
        assert_eq!(
            "cannot decode 3-byte SQL Server integer as i64",
            err.to_string()
        );
    }

    #[test]
    fn null_values_report_the_target_type() {
        let value = MssqlValue::null(MssqlTypeInfo::INT);
        assert!(value.is_null());
        assert!(value.as_ref().is_null());

        let err = <i32 as Decode<Mssql>>::decode(value.as_ref()).unwrap_err();
        assert_eq!("cannot decode SQL Server NULL as i32", err.to_string());
        assert!(<String as Decode<Mssql>>::decode(value.as_ref()).is_err());
        assert!(<&[u8] as Decode<Mssql>>::decode(value.as_ref()).is_err());
    }

    #[test]
    fn decodes_bit_values() {
        let decode = |byte: u8| {
            let value = MssqlValue::new(MssqlTypeInfo::BIT, Some(vec![byte]));
            <bool as Decode<Mssql>>::decode(value.as_ref())
        };

        assert!(decode(1).unwrap());
        assert!(!decode(0).unwrap());
        assert!(decode(2).is_err());

        let mut buf = Vec::new();
        let _ = <bool as Encode<Mssql>>::encode_by_ref(&true, &mut buf).unwrap();
        assert_eq!(vec![1_u8], buf);
    }

    #[test]
    fn decodes_real_and_float_as_f64() {
        let real = MssqlValue::new(MssqlTypeInfo::REAL, Some(1.5_f32.to_le_bytes().to_vec()));
        assert_eq!(1.5_f64, <f64 as Decode<Mssql>>::decode(real.as_ref()).unwrap());

        let float = MssqlValue::new(MssqlTypeInfo::FLOAT, Some(2.25_f64.to_le_bytes().to_vec()));
        assert_eq!(2.25_f64, <f64 as Decode<Mssql>>::decode(float.as_ref()).unwrap());

        let short = MssqlValue::new(MssqlTypeInfo::FLOAT, Some(vec![0; 3]));
        assert!(<f64 as Decode<Mssql>>::decode(short.as_ref()).is_err());
    }

    #[test]
    fn decodes_borrowed_bytes() {
        let value = MssqlValue::new(MssqlTypeInfo::VARBINARY, Some(vec![1, 2, 3, 4]));
        let bytes = <&[u8] as Decode<Mssql>>::decode(value.as_ref()).unwrap();

        assert_eq!(&[1, 2, 3, 4], bytes);
    }

    #[test]
    fn decodes_ascii_varchar_as_utf8() {
        let value = MssqlValue::new(MssqlTypeInfo::VARCHAR, Some(b"hello".to_vec()));

        assert_eq!(
            "hello",
            <String as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );
    }

    #[test]
    fn decodes_nvarchar_as_utf16() {
        let mut data = Vec::new();
        for unit in "hello".encode_utf16() {
            data.extend_from_slice(&unit.to_le_bytes());
        }

        let value = MssqlValue::new(MssqlTypeInfo::NVARCHAR, Some(data));

        assert_eq!(
            "hello",
            <String as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );
    }

    #[test]
    fn encodes_text_as_utf16le_and_round_trips() {
        let mut buf = Vec::new();
        let _ = <str as Encode<Mssql>>::encode_by_ref("h\u{e9}", &mut buf).unwrap();
        assert_eq!(vec![b'h', 0x00, 0xE9, 0x00], buf);

        let value = MssqlValue::new(MssqlTypeInfo::NVARCHAR, Some(buf));
        assert_eq!(
            "h\u{e9}",
            <String as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );
    }

    #[test]
    fn rejects_odd_length_utf16() {
        let value = MssqlValue::new(MssqlTypeInfo::NVARCHAR, Some(vec![b'h', 0x00, b'i']));
        assert!(<String as Decode<Mssql>>::decode(value.as_ref()).is_err());
    }

    #[test]
    fn parameter_sizes_clamp_to_max_marker() {
        assert_eq!(2, nvarchar_parameter_size(""));
        assert_eq!(6, nvarchar_parameter_size("abc"));
        assert_eq!(8000, nvarchar_parameter_size(&"x".repeat(4000)));
        assert_eq!(u16::MAX, nvarchar_parameter_size(&"x".repeat(4001)));

        assert_eq!(1, varbinary_parameter_size(0));
        assert_eq!(8000, varbinary_parameter_size(8000));
        assert_eq!(u16::MAX, varbinary_parameter_size(8001));
    }

    #[test]
    fn text_and_binary_parameters_advertise_sizes() {
        assert_eq!(
            Some(MssqlTypeInfo::with_size(MssqlType::NVarChar, 6)),
            <str as Encode<Mssql>>::produces("abc")
        );
        assert_eq!(
            Some(MssqlTypeInfo::with_size(MssqlType::VarBinary, 3)),
            <[u8] as Encode<Mssql>>::produces(&[1, 2, 3])
        );
    }
}