Skip to main content

sql_orm_tiberius/
row.rs

1use crate::error::{TiberiusErrorContext, map_tiberius_error};
2use chrono::{NaiveDate, NaiveDateTime};
3use rust_decimal::Decimal;
4use sql_orm_core::{OrmError, Row as OrmRow, SqlValue};
5use tiberius::{ColumnType, Row};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Copy)]
9pub struct MssqlRow<'a> {
10    inner: &'a Row,
11}
12
13impl<'a> MssqlRow<'a> {
14    pub fn new(inner: &'a Row) -> Self {
15        Self { inner }
16    }
17
18    pub fn inner(&self) -> &'a Row {
19        self.inner
20    }
21}
22
23impl OrmRow for MssqlRow<'_> {
24    fn try_get(&self, column: &str) -> Result<Option<SqlValue>, OrmError> {
25        let Some((index, column_type)) =
26            self.inner
27                .columns()
28                .iter()
29                .enumerate()
30                .find_map(|(index, metadata)| {
31                    (metadata.name() == column).then_some((index, metadata.column_type()))
32                })
33        else {
34            return Ok(None);
35        };
36
37        read_sql_value(self.inner, index, column_type).map(Some)
38    }
39}
40
41fn read_sql_value(row: &Row, index: usize, column_type: ColumnType) -> Result<SqlValue, OrmError> {
42    if let Some(value) = static_sql_value(column_type) {
43        return Ok(value);
44    }
45
46    if let Some(error) = unsupported_column_type_error(column_type) {
47        return Err(error);
48    }
49
50    match column_type {
51        ColumnType::Bit | ColumnType::Bitn => {
52            read_typed(row, index, |value: bool| SqlValue::Bool(value))
53        }
54        ColumnType::Int1 => read_typed(row, index, |value: u8| SqlValue::I32(i32::from(value))),
55        ColumnType::Int2 => read_typed(row, index, |value: i16| SqlValue::I32(i32::from(value))),
56        ColumnType::Int4 => read_typed(row, index, |value: i32| SqlValue::I32(value)),
57        ColumnType::Int8 => read_typed(row, index, |value: i64| SqlValue::I64(value)),
58        ColumnType::Intn => read_intn(row, index),
59        ColumnType::Float4 => read_typed(row, index, |value: f32| SqlValue::F64(f64::from(value))),
60        ColumnType::Float8 | ColumnType::Floatn | ColumnType::Money | ColumnType::Money4 => {
61            read_typed(row, index, |value: f64| SqlValue::F64(value))
62        }
63        ColumnType::Guid => read_typed(row, index, |value: Uuid| SqlValue::Uuid(value)),
64        ColumnType::Decimaln | ColumnType::Numericn => {
65            read_typed(row, index, |value: Decimal| SqlValue::Decimal(value))
66        }
67        ColumnType::Daten => read_typed(row, index, |value: NaiveDate| SqlValue::Date(value)),
68        ColumnType::Datetime
69        | ColumnType::Datetime4
70        | ColumnType::Datetimen
71        | ColumnType::Datetime2 => {
72            read_typed(row, index, |value: NaiveDateTime| SqlValue::DateTime(value))
73        }
74        ColumnType::BigVarChar
75        | ColumnType::BigChar
76        | ColumnType::NVarchar
77        | ColumnType::NChar
78        | ColumnType::Text
79        | ColumnType::NText => read_string(row, index),
80        ColumnType::BigVarBin | ColumnType::BigBinary | ColumnType::Image => read_bytes(row, index),
81        ColumnType::Null
82        | ColumnType::Timen
83        | ColumnType::DatetimeOffsetn
84        | ColumnType::Xml
85        | ColumnType::Udt
86        | ColumnType::SSVariant => {
87            unreachable!("special-case column type should have returned early")
88        }
89    }
90}
91
92fn static_sql_value(column_type: ColumnType) -> Option<SqlValue> {
93    match column_type {
94        ColumnType::Null => Some(SqlValue::Null),
95        _ => None,
96    }
97}
98
99fn unsupported_column_type_error(column_type: ColumnType) -> Option<OrmError> {
100    match column_type {
101        ColumnType::Timen
102        | ColumnType::DatetimeOffsetn
103        | ColumnType::Xml
104        | ColumnType::Udt
105        | ColumnType::SSVariant => Some(OrmError::new(
106            "unsupported SQL Server column type in MssqlRow",
107        )),
108        _ => None,
109    }
110}
111
112fn read_typed<T>(
113    row: &Row,
114    index: usize,
115    map: impl FnOnce(T) -> SqlValue,
116) -> Result<SqlValue, OrmError>
117where
118    for<'a> T: tiberius::FromSql<'a>,
119{
120    let value = row
121        .try_get::<T, _>(index)
122        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
123
124    Ok(value.map(map).unwrap_or(SqlValue::Null))
125}
126
127fn read_string(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
128    let value = row
129        .try_get::<&str, _>(index)
130        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
131
132    Ok(value
133        .map(|value| SqlValue::String(value.to_owned()))
134        .unwrap_or(SqlValue::Null))
135}
136
137fn read_bytes(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
138    let value = row
139        .try_get::<&[u8], _>(index)
140        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
141
142    Ok(value
143        .map(|value| SqlValue::Bytes(value.to_vec()))
144        .unwrap_or(SqlValue::Null))
145}
146
147fn read_intn(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
148    if let Some(value) = row
149        .try_get::<i64, _>(index)
150        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
151    {
152        return Ok(SqlValue::I64(value));
153    }
154
155    if let Some(value) = row
156        .try_get::<i32, _>(index)
157        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
158    {
159        return Ok(SqlValue::I32(value));
160    }
161
162    if let Some(value) = row
163        .try_get::<i16, _>(index)
164        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
165    {
166        return Ok(SqlValue::I32(i32::from(value)));
167    }
168
169    if let Some(value) = row
170        .try_get::<u8, _>(index)
171        .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
172    {
173        return Ok(SqlValue::I32(i32::from(value)));
174    }
175
176    Ok(SqlValue::Null)
177}
178
#[cfg(test)]
mod tests {
    use super::{static_sql_value, unsupported_column_type_error};
    use sql_orm_core::SqlValue;
    use tiberius::ColumnType;

    /// Column types `MssqlRow` deliberately refuses to decode.
    const UNSUPPORTED: [ColumnType; 5] = [
        ColumnType::Timen,
        ColumnType::DatetimeOffsetn,
        ColumnType::Xml,
        ColumnType::Udt,
        ColumnType::SSVariant,
    ];

    /// A representative sample of column types that go through the typed decoders.
    const DECODABLE: [ColumnType; 10] = [
        ColumnType::Bit,
        ColumnType::Int4,
        ColumnType::Intn,
        ColumnType::Floatn,
        ColumnType::Money,
        ColumnType::Guid,
        ColumnType::Decimaln,
        ColumnType::Datetime2,
        ColumnType::NVarchar,
        ColumnType::BigVarBin,
    ];

    #[test]
    fn reports_unsupported_sql_server_column_types() {
        for column_type in UNSUPPORTED {
            let error = unsupported_column_type_error(column_type).unwrap();
            assert_eq!(
                error.message(),
                "unsupported SQL Server column type in MssqlRow"
            );
        }
    }

    #[test]
    fn treats_sql_null_columns_as_sql_value_null() {
        let value = static_sql_value(ColumnType::Null).unwrap();

        assert_eq!(value, SqlValue::Null);
    }

    // `read_sql_value` checks `static_sql_value` before
    // `unsupported_column_type_error`, so the two sets must stay disjoint.
    #[test]
    fn static_and_unsupported_column_types_do_not_overlap() {
        assert!(unsupported_column_type_error(ColumnType::Null).is_none());

        for column_type in UNSUPPORTED {
            assert!(static_sql_value(column_type).is_none());
        }
    }

    #[test]
    fn decodable_column_types_reach_the_typed_decoders() {
        for column_type in DECODABLE {
            assert!(static_sql_value(column_type).is_none());
            assert!(unsupported_column_type_error(column_type).is_none());
        }
    }
}
208}