Skip to main content

sqlx_sqlserver/
value.rs

1use std::borrow::Cow;
2
3use sqlx_core::decode::Decode;
4use sqlx_core::encode::{Encode, IsNull};
5use sqlx_core::error::BoxDynError;
6use sqlx_core::types::Type;
7use sqlx_core::value::{Value, ValueRef};
8
9use crate::{Mssql, MssqlType, MssqlTypeInfo};
10
11/// Owned SQL Server value skeleton.
12#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct MssqlValue {
14    type_info: MssqlTypeInfo,
15    data: Option<Vec<u8>>,
16}
17
18impl MssqlValue {
19    /// Creates an owned value from type information and raw little-endian TDS bytes.
20    pub(crate) fn new(type_info: MssqlTypeInfo, data: Option<Vec<u8>>) -> Self {
21        Self { type_info, data }
22    }
23
24    /// Creates a `NULL` value with the supplied type information.
25    pub fn null(type_info: MssqlTypeInfo) -> Self {
26        Self {
27            type_info,
28            data: None,
29        }
30    }
31}
32
33impl Value for MssqlValue {
34    type Database = Mssql;
35
36    fn as_ref(&self) -> MssqlValueRef<'_> {
37        MssqlValueRef {
38            type_info: &self.type_info,
39            data: self.data.as_deref(),
40        }
41    }
42
43    fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
44        Cow::Borrowed(&self.type_info)
45    }
46
47    fn is_null(&self) -> bool {
48        self.data.is_none()
49    }
50}
51
52/// Borrowed SQL Server value skeleton.
53#[derive(Debug, Clone, Copy)]
54pub struct MssqlValueRef<'r> {
55    type_info: &'r MssqlTypeInfo,
56    data: Option<&'r [u8]>,
57}
58
59impl<'r> ValueRef<'r> for MssqlValueRef<'r> {
60    type Database = Mssql;
61
62    fn to_owned(&self) -> MssqlValue {
63        MssqlValue {
64            type_info: self.type_info.clone(),
65            data: self.data.map(ToOwned::to_owned),
66        }
67    }
68
69    fn type_info(&self) -> Cow<'_, MssqlTypeInfo> {
70        Cow::Borrowed(self.type_info)
71    }
72
73    fn is_null(&self) -> bool {
74        self.data.is_none()
75    }
76}
77
78impl<'r> MssqlValueRef<'r> {
79    pub(crate) fn as_bytes(&self) -> Option<&'r [u8]> {
80        self.data
81    }
82}
83
84fn non_null_bytes<'r>(value: MssqlValueRef<'r>, rust_type: &str) -> Result<&'r [u8], BoxDynError> {
85    value
86        .as_bytes()
87        .ok_or_else(|| format!("cannot decode SQL Server NULL as {rust_type}").into())
88}
89
90fn decode_integer(value: MssqlValueRef<'_>, rust_type: &str) -> Result<i64, BoxDynError> {
91    let bytes = non_null_bytes(value, rust_type)?;
92
93    match bytes.len() {
94        1 => Ok(i64::from(bytes[0])),
95        2 => Ok(i64::from(i16::from_le_bytes([bytes[0], bytes[1]]))),
96        4 => Ok(i64::from(i32::from_le_bytes([
97            bytes[0], bytes[1], bytes[2], bytes[3],
98        ]))),
99        8 => Ok(i64::from_le_bytes([
100            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
101        ])),
102        len => Err(format!("cannot decode {len}-byte SQL Server integer as {rust_type}").into()),
103    }
104}
105
106impl Type<Mssql> for i8 {
107    fn type_info() -> MssqlTypeInfo {
108        MssqlTypeInfo::SMALLINT
109    }
110
111    fn compatible(ty: &MssqlTypeInfo) -> bool {
112        matches!(
113            ty.kind(),
114            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
115        )
116    }
117}
118
119impl Encode<'_, Mssql> for i8 {
120    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
121        <i16 as Encode<Mssql>>::encode_by_ref(&i16::from(*self), buf)
122    }
123}
124
125impl Decode<'_, Mssql> for i8 {
126    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
127        Ok(i8::try_from(decode_integer(value, "i8")?)?)
128    }
129}
130
131impl Type<Mssql> for u8 {
132    fn type_info() -> MssqlTypeInfo {
133        MssqlTypeInfo::TINYINT
134    }
135
136    fn compatible(ty: &MssqlTypeInfo) -> bool {
137        matches!(
138            ty.kind(),
139            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
140        )
141    }
142}
143
144impl Encode<'_, Mssql> for u8 {
145    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
146        buf.push(*self);
147        Ok(IsNull::No)
148    }
149}
150
151impl Decode<'_, Mssql> for u8 {
152    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
153        Ok(u8::try_from(decode_integer(value, "u8")?)?)
154    }
155}
156
157impl Type<Mssql> for i32 {
158    fn type_info() -> MssqlTypeInfo {
159        MssqlTypeInfo::INT
160    }
161
162    fn compatible(ty: &MssqlTypeInfo) -> bool {
163        matches!(
164            ty.kind(),
165            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int
166        )
167    }
168}
169
170impl Encode<'_, Mssql> for i32 {
171    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
172        buf.extend_from_slice(&self.to_le_bytes());
173        Ok(IsNull::No)
174    }
175}
176
177impl Type<Mssql> for bool {
178    fn type_info() -> MssqlTypeInfo {
179        MssqlTypeInfo::BIT
180    }
181
182    fn compatible(ty: &MssqlTypeInfo) -> bool {
183        matches!(ty.kind(), MssqlType::Bit)
184    }
185}
186
187impl Encode<'_, Mssql> for bool {
188    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
189        buf.push(u8::from(*self));
190        Ok(IsNull::No)
191    }
192}
193
194impl Decode<'_, Mssql> for bool {
195    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
196        let bytes = value
197            .as_bytes()
198            .ok_or_else(|| "cannot decode SQL Server NULL as bool".to_owned())?;
199
200        match bytes {
201            [0] => Ok(false),
202            [1] => Ok(true),
203            _ => Err("cannot decode SQL Server bit as bool".into()),
204        }
205    }
206}
207
208impl Type<Mssql> for i16 {
209    fn type_info() -> MssqlTypeInfo {
210        MssqlTypeInfo::SMALLINT
211    }
212
213    fn compatible(ty: &MssqlTypeInfo) -> bool {
214        matches!(ty.kind(), MssqlType::TinyInt | MssqlType::SmallInt)
215    }
216}
217
218impl Encode<'_, Mssql> for i16 {
219    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
220        buf.extend_from_slice(&self.to_le_bytes());
221        Ok(IsNull::No)
222    }
223}
224
225impl Decode<'_, Mssql> for i16 {
226    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
227        Ok(i16::try_from(decode_integer(value, "i16")?)?)
228    }
229}
230
231impl Type<Mssql> for u16 {
232    fn type_info() -> MssqlTypeInfo {
233        MssqlTypeInfo::INT
234    }
235
236    fn compatible(ty: &MssqlTypeInfo) -> bool {
237        matches!(
238            ty.kind(),
239            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
240        )
241    }
242}
243
244impl Encode<'_, Mssql> for u16 {
245    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
246        <i32 as Encode<Mssql>>::encode_by_ref(&i32::from(*self), buf)
247    }
248}
249
250impl Decode<'_, Mssql> for u16 {
251    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
252        Ok(u16::try_from(decode_integer(value, "u16")?)?)
253    }
254}
255
256impl Decode<'_, Mssql> for i32 {
257    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
258        Ok(i32::try_from(decode_integer(value, "i32")?)?)
259    }
260}
261
262impl Type<Mssql> for u32 {
263    fn type_info() -> MssqlTypeInfo {
264        MssqlTypeInfo::BIGINT
265    }
266
267    fn compatible(ty: &MssqlTypeInfo) -> bool {
268        matches!(
269            ty.kind(),
270            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
271        )
272    }
273}
274
275impl Encode<'_, Mssql> for u32 {
276    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
277        <i64 as Encode<Mssql>>::encode_by_ref(&i64::from(*self), buf)
278    }
279}
280
281impl Decode<'_, Mssql> for u32 {
282    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
283        Ok(u32::try_from(decode_integer(value, "u32")?)?)
284    }
285}
286
287impl Type<Mssql> for i64 {
288    fn type_info() -> MssqlTypeInfo {
289        MssqlTypeInfo::BIGINT
290    }
291
292    fn compatible(ty: &MssqlTypeInfo) -> bool {
293        matches!(
294            ty.kind(),
295            MssqlType::TinyInt | MssqlType::SmallInt | MssqlType::Int | MssqlType::BigInt
296        )
297    }
298}
299
300impl Encode<'_, Mssql> for i64 {
301    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
302        buf.extend_from_slice(&self.to_le_bytes());
303        Ok(IsNull::No)
304    }
305}
306
307impl Decode<'_, Mssql> for i64 {
308    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
309        decode_integer(value, "i64")
310    }
311}
312
313impl Type<Mssql> for f32 {
314    fn type_info() -> MssqlTypeInfo {
315        MssqlTypeInfo::REAL
316    }
317
318    fn compatible(ty: &MssqlTypeInfo) -> bool {
319        matches!(ty.kind(), MssqlType::Real)
320    }
321}
322
323impl Encode<'_, Mssql> for f32 {
324    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
325        buf.extend_from_slice(&self.to_le_bytes());
326        Ok(IsNull::No)
327    }
328}
329
330impl Decode<'_, Mssql> for f32 {
331    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
332        let bytes = value
333            .as_bytes()
334            .ok_or_else(|| "cannot decode SQL Server NULL as f32".to_owned())?;
335
336        match bytes {
337            [a, b, c, d] => Ok(f32::from_le_bytes([*a, *b, *c, *d])),
338            _ => Err("cannot decode SQL Server real as f32".into()),
339        }
340    }
341}
342
343impl Type<Mssql> for f64 {
344    fn type_info() -> MssqlTypeInfo {
345        MssqlTypeInfo::FLOAT
346    }
347
348    fn compatible(ty: &MssqlTypeInfo) -> bool {
349        matches!(ty.kind(), MssqlType::Real | MssqlType::Float)
350    }
351}
352
353impl Decode<'_, Mssql> for f64 {
354    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
355        match value
356            .as_bytes()
357            .ok_or_else(|| "cannot decode SQL Server NULL as f64".to_owned())?
358        {
359            [a, b, c, d] => Ok(f64::from(f32::from_le_bytes([*a, *b, *c, *d]))),
360            [a, b, c, d, e, f, g, h] => Ok(f64::from_le_bytes([*a, *b, *c, *d, *e, *f, *g, *h])),
361            _ => Err("cannot decode SQL Server float as f64".into()),
362        }
363    }
364}
365
366impl Encode<'_, Mssql> for f64 {
367    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
368        buf.extend_from_slice(&self.to_le_bytes());
369        Ok(IsNull::No)
370    }
371}
372
373impl Type<Mssql> for str {
374    fn type_info() -> MssqlTypeInfo {
375        MssqlTypeInfo::NVARCHAR
376    }
377
378    fn compatible(ty: &MssqlTypeInfo) -> bool {
379        matches!(ty.kind(), MssqlType::NVarChar)
380    }
381}
382
383impl Type<Mssql> for String {
384    fn type_info() -> MssqlTypeInfo {
385        <str as Type<Mssql>>::type_info()
386    }
387
388    fn compatible(ty: &MssqlTypeInfo) -> bool {
389        <str as Type<Mssql>>::compatible(ty)
390    }
391}
392
393impl Encode<'_, Mssql> for str {
394    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
395        for unit in self.encode_utf16() {
396            buf.extend_from_slice(&unit.to_le_bytes());
397        }
398
399        Ok(IsNull::No)
400    }
401}
402
403impl<'q> Encode<'q, Mssql> for &'q str {
404    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
405        <str as Encode<Mssql>>::encode_by_ref(*self, buf)
406    }
407}
408
409impl Encode<'_, Mssql> for String {
410    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
411        <str as Encode<Mssql>>::encode_by_ref(self.as_str(), buf)
412    }
413}
414
415impl Decode<'_, Mssql> for String {
416    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
417        let bytes = value
418            .as_bytes()
419            .ok_or_else(|| "cannot decode SQL Server NULL as String".to_owned())?;
420
421        if bytes.len() % 2 != 0 {
422            return Err("cannot decode odd-length SQL Server UTF-16 text".into());
423        }
424
425        let units = bytes
426            .chunks_exact(2)
427            .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
428            .collect::<Vec<_>>();
429        Ok(String::from_utf16(&units)?)
430    }
431}
432
433impl Type<Mssql> for [u8] {
434    fn type_info() -> MssqlTypeInfo {
435        MssqlTypeInfo::VARBINARY
436    }
437
438    fn compatible(ty: &MssqlTypeInfo) -> bool {
439        matches!(ty.kind(), MssqlType::VarBinary)
440    }
441}
442
443impl Type<Mssql> for Vec<u8> {
444    fn type_info() -> MssqlTypeInfo {
445        <[u8] as Type<Mssql>>::type_info()
446    }
447
448    fn compatible(ty: &MssqlTypeInfo) -> bool {
449        <[u8] as Type<Mssql>>::compatible(ty)
450    }
451}
452
453impl Encode<'_, Mssql> for [u8] {
454    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
455        buf.extend_from_slice(self);
456        Ok(IsNull::No)
457    }
458}
459
460impl<'q> Encode<'q, Mssql> for &'q [u8] {
461    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
462        <[u8] as Encode<Mssql>>::encode_by_ref(*self, buf)
463    }
464}
465
466impl Encode<'_, Mssql> for Vec<u8> {
467    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
468        <[u8] as Encode<Mssql>>::encode_by_ref(self.as_slice(), buf)
469    }
470}
471
472impl Decode<'_, Mssql> for Vec<u8> {
473    fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
474        Ok(<&[u8] as Decode<Mssql>>::decode(value)?.to_vec())
475    }
476}
477
478impl<'r> Decode<'r, Mssql> for &'r [u8] {
479    fn decode(value: MssqlValueRef<'r>) -> Result<Self, BoxDynError> {
480        non_null_bytes(value, "bytes")
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps raw wire bytes in an owned, non-NULL value for decoding.
    fn raw_value(type_info: MssqlTypeInfo, bytes: &[u8]) -> MssqlValue {
        MssqlValue::new(type_info, Some(bytes.to_vec()))
    }

    #[test]
    fn integer_scalars_use_lossless_parameter_types() {
        assert_eq!(MssqlTypeInfo::SMALLINT, <i8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::TINYINT, <u8 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::INT, <u16 as Type<Mssql>>::type_info());
        assert_eq!(MssqlTypeInfo::BIGINT, <u32 as Type<Mssql>>::type_info());
    }

    #[test]
    fn encodes_unsigned_integer_scalars_without_saturation() {
        let mut buf = Vec::new();
        assert!(matches!(
            <u32 as Encode<Mssql>>::encode_by_ref(&u32::MAX, &mut buf).unwrap(),
            IsNull::No
        ));
        assert_eq!(i64::from(u32::MAX).to_le_bytes(), buf.as_slice());

        buf.clear();
        assert!(matches!(
            <u16 as Encode<Mssql>>::encode_by_ref(&u16::MAX, &mut buf).unwrap(),
            IsNull::No
        ));
        assert_eq!(i32::from(u16::MAX).to_le_bytes(), buf.as_slice());
    }

    #[test]
    fn decodes_integer_scalars_with_range_checks() {
        let value = raw_value(MssqlTypeInfo::INT, &65_535_i32.to_le_bytes());
        assert_eq!(
            65_535_u16,
            <u16 as Decode<Mssql>>::decode(value.as_ref()).unwrap()
        );

        let negative = raw_value(MssqlTypeInfo::INT, &(-1_i32).to_le_bytes());
        assert!(<u16 as Decode<Mssql>>::decode(negative.as_ref()).is_err());

        let too_large = raw_value(MssqlTypeInfo::INT, &128_i32.to_le_bytes());
        assert!(<i8 as Decode<Mssql>>::decode(too_large.as_ref()).is_err());
    }

    #[test]
    fn decodes_tinyint_as_unsigned() {
        // TINYINT is unsigned, so 200 must not be read as a negative byte.
        let tiny = raw_value(MssqlTypeInfo::TINYINT, &[200]);
        assert_eq!(200_u8, <u8 as Decode<Mssql>>::decode(tiny.as_ref()).unwrap());
        assert_eq!(200_i16, <i16 as Decode<Mssql>>::decode(tiny.as_ref()).unwrap());
        assert!(<i8 as Decode<Mssql>>::decode(tiny.as_ref()).is_err());
    }

    #[test]
    fn rejects_unexpected_integer_widths() {
        let odd = raw_value(MssqlTypeInfo::INT, &[1, 2, 3]);
        let err = <i32 as Decode<Mssql>>::decode(odd.as_ref()).unwrap_err();
        assert_eq!("cannot decode 3-byte SQL Server integer as i32", err.to_string());
    }

    #[test]
    fn null_values_fail_to_decode() {
        let null = MssqlValue::null(MssqlTypeInfo::INT);
        assert!(null.is_null());
        assert!(null.as_ref().is_null());

        let err = <i32 as Decode<Mssql>>::decode(null.as_ref()).unwrap_err();
        assert_eq!("cannot decode SQL Server NULL as i32", err.to_string());
        assert!(<bool as Decode<Mssql>>::decode(null.as_ref()).is_err());
        assert!(<String as Decode<Mssql>>::decode(null.as_ref()).is_err());
        assert!(<&[u8] as Decode<Mssql>>::decode(null.as_ref()).is_err());
    }

    #[test]
    fn round_trips_utf16_text() {
        let text = "héllo, 世界 🎉";
        let mut buf = Vec::new();
        assert!(matches!(
            <str as Encode<Mssql>>::encode_by_ref(text, &mut buf).unwrap(),
            IsNull::No
        ));
        assert_eq!(text.encode_utf16().count() * 2, buf.len());

        let encoded = MssqlValue::new(MssqlTypeInfo::NVARCHAR, Some(buf));
        assert_eq!(
            text,
            <String as Decode<Mssql>>::decode(encoded.as_ref()).unwrap()
        );

        let odd = raw_value(MssqlTypeInfo::NVARCHAR, &[0x41]);
        assert!(<String as Decode<Mssql>>::decode(odd.as_ref()).is_err());
    }

    #[test]
    fn decodes_bit_strictly() {
        let decode = |byte: u8| {
            let bit = raw_value(MssqlTypeInfo::BIT, &[byte]);
            <bool as Decode<Mssql>>::decode(bit.as_ref()).ok()
        };

        assert_eq!(Some(false), decode(0));
        assert_eq!(Some(true), decode(1));
        assert_eq!(None, decode(2));
    }

    #[test]
    fn f64_accepts_real_and_float_payloads() {
        let real = raw_value(MssqlTypeInfo::REAL, &1.5_f32.to_le_bytes());
        assert_eq!(1.5_f64, <f64 as Decode<Mssql>>::decode(real.as_ref()).unwrap());

        let float = raw_value(MssqlTypeInfo::FLOAT, &(-2.25_f64).to_le_bytes());
        assert_eq!(-2.25_f64, <f64 as Decode<Mssql>>::decode(float.as_ref()).unwrap());

        let short = raw_value(MssqlTypeInfo::FLOAT, &[0, 0]);
        assert!(<f64 as Decode<Mssql>>::decode(short.as_ref()).is_err());
    }

    #[test]
    fn decodes_borrowed_bytes() {
        let value = raw_value(MssqlTypeInfo::VARBINARY, &[1, 2, 3, 4]);
        let bytes = <&[u8] as Decode<Mssql>>::decode(value.as_ref()).unwrap();
        assert_eq!(&[1, 2, 3, 4], bytes);

        let owned = <Vec<u8> as Decode<Mssql>>::decode(value.as_ref()).unwrap();
        assert_eq!(vec![1, 2, 3, 4], owned);
    }
}