1use crate::error::{TiberiusErrorContext, map_tiberius_error};
2use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
3use rust_decimal::Decimal;
4use sql_orm_core::{OrmError, Row as OrmRow, SqlValue};
5use tiberius::{ColumnType, Row};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Copy)]
9pub struct MssqlRow<'a> {
10 inner: &'a Row,
11}
12
13impl<'a> MssqlRow<'a> {
14 pub fn new(inner: &'a Row) -> Self {
15 Self { inner }
16 }
17
18 pub fn inner(&self) -> &'a Row {
19 self.inner
20 }
21}
22
23impl OrmRow for MssqlRow<'_> {
24 fn try_get(&self, column: &str) -> Result<Option<SqlValue>, OrmError> {
25 let Some((index, column_type)) =
26 self.inner
27 .columns()
28 .iter()
29 .enumerate()
30 .find_map(|(index, metadata)| {
31 (metadata.name() == column).then_some((index, metadata.column_type()))
32 })
33 else {
34 return Ok(None);
35 };
36
37 read_sql_value(self.inner, index, column_type).map(Some)
38 }
39}
40
41fn read_sql_value(row: &Row, index: usize, column_type: ColumnType) -> Result<SqlValue, OrmError> {
42 if let Some(value) = static_sql_value(column_type) {
43 return Ok(value);
44 }
45
46 if let Some(error) = unsupported_column_type_error(column_type) {
47 return Err(error);
48 }
49
50 match column_type {
51 ColumnType::Bit | ColumnType::Bitn => {
52 read_typed(row, index, |value: bool| SqlValue::Bool(value))
53 }
54 ColumnType::Int1 => read_typed(row, index, |value: u8| SqlValue::I32(i32::from(value))),
55 ColumnType::Int2 => read_typed(row, index, |value: i16| SqlValue::I32(i32::from(value))),
56 ColumnType::Int4 => read_typed(row, index, |value: i32| SqlValue::I32(value)),
57 ColumnType::Int8 => read_typed(row, index, |value: i64| SqlValue::I64(value)),
58 ColumnType::Intn => read_intn(row, index),
59 ColumnType::Float4 => read_typed(row, index, |value: f32| SqlValue::F64(f64::from(value))),
60 ColumnType::Float8 | ColumnType::Floatn | ColumnType::Money | ColumnType::Money4 => {
61 read_typed(row, index, |value: f64| SqlValue::F64(value))
62 }
63 ColumnType::Guid => read_typed(row, index, |value: Uuid| SqlValue::Uuid(value)),
64 ColumnType::Decimaln | ColumnType::Numericn => {
65 read_typed(row, index, |value: Decimal| SqlValue::Decimal(value))
66 }
67 ColumnType::Daten => read_typed(row, index, |value: NaiveDate| SqlValue::Date(value)),
68 ColumnType::Timen => read_typed(row, index, |value: NaiveTime| SqlValue::Time(value)),
69 ColumnType::Datetime
70 | ColumnType::Datetime4
71 | ColumnType::Datetimen
72 | ColumnType::Datetime2 => {
73 read_typed(row, index, |value: NaiveDateTime| SqlValue::DateTime(value))
74 }
75 ColumnType::DatetimeOffsetn => read_typed(row, index, |value: DateTime<FixedOffset>| {
76 SqlValue::DateTimeOffset(value)
77 }),
78 ColumnType::BigVarChar
79 | ColumnType::BigChar
80 | ColumnType::NVarchar
81 | ColumnType::NChar
82 | ColumnType::Text
83 | ColumnType::NText => read_string(row, index),
84 ColumnType::BigVarBin | ColumnType::BigBinary | ColumnType::Image => read_bytes(row, index),
85 ColumnType::Null | ColumnType::Xml | ColumnType::Udt | ColumnType::SSVariant => {
86 unreachable!("special-case column type should have returned early")
87 }
88 }
89}
90
91fn static_sql_value(column_type: ColumnType) -> Option<SqlValue> {
92 match column_type {
93 ColumnType::Null => Some(SqlValue::Null),
94 _ => None,
95 }
96}
97
98fn unsupported_column_type_error(column_type: ColumnType) -> Option<OrmError> {
99 match column_type {
100 ColumnType::Xml | ColumnType::Udt | ColumnType::SSVariant => Some(OrmError::mapping(
101 "unsupported SQL Server column type in MssqlRow",
102 )),
103 _ => None,
104 }
105}
106
107fn read_typed<T>(
108 row: &Row,
109 index: usize,
110 map: impl FnOnce(T) -> SqlValue,
111) -> Result<SqlValue, OrmError>
112where
113 for<'a> T: tiberius::FromSql<'a>,
114{
115 let value = row
116 .try_get::<T, _>(index)
117 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
118
119 Ok(value.map(map).unwrap_or(SqlValue::Null))
120}
121
122fn read_string(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
123 let value = row
124 .try_get::<&str, _>(index)
125 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
126
127 Ok(value
128 .map(|value| SqlValue::String(value.to_owned()))
129 .unwrap_or(SqlValue::Null))
130}
131
132fn read_bytes(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
133 let value = row
134 .try_get::<&[u8], _>(index)
135 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
136
137 Ok(value
138 .map(|value| SqlValue::Bytes(value.to_vec()))
139 .unwrap_or(SqlValue::Null))
140}
141
142fn read_intn(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
143 if let Some(value) = row
144 .try_get::<i64, _>(index)
145 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
146 {
147 return Ok(SqlValue::I64(value));
148 }
149
150 if let Some(value) = row
151 .try_get::<i32, _>(index)
152 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
153 {
154 return Ok(SqlValue::I32(value));
155 }
156
157 if let Some(value) = row
158 .try_get::<i16, _>(index)
159 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
160 {
161 return Ok(SqlValue::I32(i32::from(value)));
162 }
163
164 if let Some(value) = row
165 .try_get::<u8, _>(index)
166 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
167 {
168 return Ok(SqlValue::I32(i32::from(value)));
169 }
170
171 Ok(SqlValue::Null)
172}
173
#[cfg(test)]
mod tests {
    use super::{static_sql_value, unsupported_column_type_error};
    use sql_orm_core::{OrmErrorKind, SqlValue};
    use tiberius::ColumnType;

    #[test]
    fn reports_unsupported_sql_server_column_types() {
        for column_type in [ColumnType::Xml, ColumnType::Udt, ColumnType::SSVariant] {
            let error = unsupported_column_type_error(column_type)
                .expect("column type should be reported as unsupported");
            assert_eq!(
                error.message(),
                "unsupported SQL Server column type in MssqlRow"
            );
            assert_eq!(error.kind(), OrmErrorKind::Mapping);
        }
    }

    #[test]
    fn does_not_report_supported_sql_server_column_types() {
        for column_type in [
            ColumnType::Null,
            ColumnType::Int4,
            ColumnType::Intn,
            ColumnType::Floatn,
            ColumnType::NVarchar,
            ColumnType::BigVarBin,
        ] {
            assert!(unsupported_column_type_error(column_type).is_none());
        }
    }

    #[test]
    fn treats_sql_null_columns_as_sql_value_null() {
        let value = static_sql_value(ColumnType::Null).unwrap();

        assert_eq!(value, SqlValue::Null);
    }

    #[test]
    fn leaves_non_null_column_types_to_typed_readers() {
        for column_type in [ColumnType::Int4, ColumnType::Intn, ColumnType::NVarchar] {
            assert!(static_sql_value(column_type).is_none());
        }
    }
}