Skip to main content

sql_orm_tiberius/
row.rs

1use crate::error::{TiberiusErrorContext, map_tiberius_error};
2use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
3use rust_decimal::Decimal;
4use sql_orm_core::{OrmError, Row as OrmRow, SqlValue};
5use tiberius::{ColumnType, Row};
6use uuid::Uuid;
7
/// Borrowed view over a Tiberius [`Row`].
///
/// The wrapper stores only a shared reference to the row, which is why it can
/// derive `Copy`; it is tied to the row's lifetime `'a`.
#[derive(Debug, Clone, Copy)]
pub struct MssqlRow<'a> {
    /// The Tiberius row that column values are read from.
    inner: &'a Row,
}
12
13impl<'a> MssqlRow<'a> {
14    pub fn new(inner: &'a Row) -> Self {
15        Self { inner }
16    }
17
18    pub fn inner(&self) -> &'a Row {
19        self.inner
20    }
21}
22
23impl OrmRow for MssqlRow<'_> {
24    fn try_get(&self, column: &str) -> Result<Option<SqlValue>, OrmError> {
25        let Some((index, column_type)) =
26            self.inner
27                .columns()
28                .iter()
29                .enumerate()
30                .find_map(|(index, metadata)| {
31                    (metadata.name() == column).then_some((index, metadata.column_type()))
32                })
33        else {
34            return Ok(None);
35        };
36
37        read_sql_value(self.inner, index, column_type).map(Some)
38    }
39}
40
41fn read_sql_value(row: &Row, index: usize, column_type: ColumnType) -> Result<SqlValue, OrmError> {
42    if let Some(value) = static_sql_value(column_type) {
43        return Ok(value);
44    }
45
46    if let Some(error) = unsupported_column_type_error(column_type) {
47        return Err(error);
48    }
49
50    match column_type {
51        ColumnType::Bit | ColumnType::Bitn => {
52            read_typed(row, index, |value: bool| SqlValue::Bool(value))
53        }
54        ColumnType::Int1 => read_typed(row, index, |value: u8| SqlValue::I32(i32::from(value))),
55        ColumnType::Int2 => read_typed(row, index, |value: i16| SqlValue::I32(i32::from(value))),
56        ColumnType::Int4 => read_typed(row, index, |value: i32| SqlValue::I32(value)),
57        ColumnType::Int8 => read_typed(row, index, |value: i64| SqlValue::I64(value)),
58        ColumnType::Intn => read_intn(row, index),
59        ColumnType::Float4 => read_typed(row, index, |value: f32| SqlValue::F64(f64::from(value))),
60        ColumnType::Float8 | ColumnType::Floatn | ColumnType::Money | ColumnType::Money4 => {
61            read_typed(row, index, |value: f64| SqlValue::F64(value))
62        }
63        ColumnType::Guid => read_typed(row, index, |value: Uuid| SqlValue::Uuid(value)),
64        ColumnType::Decimaln | ColumnType::Numericn => {
65            read_typed(row, index, |value: Decimal| SqlValue::Decimal(value))
66        }
67        ColumnType::Daten => read_typed(row, index, |value: NaiveDate| SqlValue::Date(value)),
68        ColumnType::Timen => read_typed(row, index, |value: NaiveTime| SqlValue::Time(value)),
69        ColumnType::Datetime
70        | ColumnType::Datetime4
71        | ColumnType::Datetimen
72        | ColumnType::Datetime2 => {
73            read_typed(row, index, |value: NaiveDateTime| SqlValue::DateTime(value))
74        }
75        ColumnType::DatetimeOffsetn => read_typed(row, index, |value: DateTime<FixedOffset>| {
76            SqlValue::DateTimeOffset(value)
77        }),
78        ColumnType::BigVarChar
79        | ColumnType::BigChar
80        | ColumnType::NVarchar
81        | ColumnType::NChar
82        | ColumnType::Text
83        | ColumnType::NText => read_string(row, index),
84        ColumnType::BigVarBin | ColumnType::BigBinary | ColumnType::Image => read_bytes(row, index),
85        ColumnType::Null | ColumnType::Xml | ColumnType::Udt | ColumnType::SSVariant => {
86            unreachable!("special-case column type should have returned early")
87        }
88    }
89}
90
91fn static_sql_value(column_type: ColumnType) -> Option<SqlValue> {
92    match column_type {
93        ColumnType::Null => Some(SqlValue::Null),
94        _ => None,
95    }
96}
97
98fn unsupported_column_type_error(column_type: ColumnType) -> Option<OrmError> {
99    match column_type {
100        ColumnType::Xml | ColumnType::Udt | ColumnType::SSVariant => Some(OrmError::mapping(
101            "unsupported SQL Server column type in MssqlRow",
102        )),
103        _ => None,
104    }
105}
106
107fn read_typed<T>(
108    row: &Row,
109    index: usize,
110    map: impl FnOnce(T) -> SqlValue,
111) -> Result<SqlValue, OrmError>
112where
113    for<'a> T: tiberius::FromSql<'a>,
114{
115    let value = row
116        .try_get::<T, _>(index)
117        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
118
119    Ok(value.map(map).unwrap_or(SqlValue::Null))
120}
121
122fn read_string(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
123    let value = row
124        .try_get::<&str, _>(index)
125        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
126
127    Ok(value
128        .map(|value| SqlValue::String(value.to_owned()))
129        .unwrap_or(SqlValue::Null))
130}
131
132fn read_bytes(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
133    let value = row
134        .try_get::<&[u8], _>(index)
135        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
136
137    Ok(value
138        .map(|value| SqlValue::Bytes(value.to_vec()))
139        .unwrap_or(SqlValue::Null))
140}
141
142fn read_intn(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
143    if let Some(value) = row
144        .try_get::<i64, _>(index)
145        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
146    {
147        return Ok(SqlValue::I64(value));
148    }
149
150    if let Some(value) = row
151        .try_get::<i32, _>(index)
152        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
153    {
154        return Ok(SqlValue::I32(value));
155    }
156
157    if let Some(value) = row
158        .try_get::<i16, _>(index)
159        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
160    {
161        return Ok(SqlValue::I32(i32::from(value)));
162    }
163
164    if let Some(value) = row
165        .try_get::<u8, _>(index)
166        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
167    {
168        return Ok(SqlValue::I32(i32::from(value)));
169    }
170
171    Ok(SqlValue::Null)
172}
173
#[cfg(test)]
mod tests {
    use super::{static_sql_value, unsupported_column_type_error};
    use sql_orm_core::{OrmErrorKind, SqlValue};
    use tiberius::ColumnType;

    /// Message expected from every unsupported-column-type error.
    const UNSUPPORTED_MESSAGE: &str = "unsupported SQL Server column type in MssqlRow";

    /// Xml, Udt and SSVariant columns must each produce the mapping error.
    #[test]
    fn reports_unsupported_sql_server_column_types() {
        let rejected = [ColumnType::Xml, ColumnType::Udt, ColumnType::SSVariant];

        for column_type in rejected {
            let error = unsupported_column_type_error(column_type).unwrap();

            assert_eq!(error.message(), UNSUPPORTED_MESSAGE);
            assert_eq!(error.kind(), OrmErrorKind::Mapping);
        }
    }

    /// A `ColumnType::Null` column maps to `SqlValue::Null` without reading a cell.
    #[test]
    fn treats_sql_null_columns_as_sql_value_null() {
        let resolved = static_sql_value(ColumnType::Null);

        assert_eq!(resolved.unwrap(), SqlValue::Null);
    }
}