1use crate::error::{TiberiusErrorContext, map_tiberius_error};
2use chrono::{NaiveDate, NaiveDateTime};
3use rust_decimal::Decimal;
4use sql_orm_core::{OrmError, Row as OrmRow, SqlValue};
5use tiberius::{ColumnType, Row};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Copy)]
9pub struct MssqlRow<'a> {
10 inner: &'a Row,
11}
12
13impl<'a> MssqlRow<'a> {
14 pub fn new(inner: &'a Row) -> Self {
15 Self { inner }
16 }
17
18 pub fn inner(&self) -> &'a Row {
19 self.inner
20 }
21}
22
23impl OrmRow for MssqlRow<'_> {
24 fn try_get(&self, column: &str) -> Result<Option<SqlValue>, OrmError> {
25 let Some((index, column_type)) =
26 self.inner
27 .columns()
28 .iter()
29 .enumerate()
30 .find_map(|(index, metadata)| {
31 (metadata.name() == column).then_some((index, metadata.column_type()))
32 })
33 else {
34 return Ok(None);
35 };
36
37 read_sql_value(self.inner, index, column_type).map(Some)
38 }
39}
40
41fn read_sql_value(row: &Row, index: usize, column_type: ColumnType) -> Result<SqlValue, OrmError> {
42 if let Some(value) = static_sql_value(column_type) {
43 return Ok(value);
44 }
45
46 if let Some(error) = unsupported_column_type_error(column_type) {
47 return Err(error);
48 }
49
50 match column_type {
51 ColumnType::Bit | ColumnType::Bitn => {
52 read_typed(row, index, |value: bool| SqlValue::Bool(value))
53 }
54 ColumnType::Int1 => read_typed(row, index, |value: u8| SqlValue::I32(i32::from(value))),
55 ColumnType::Int2 => read_typed(row, index, |value: i16| SqlValue::I32(i32::from(value))),
56 ColumnType::Int4 => read_typed(row, index, |value: i32| SqlValue::I32(value)),
57 ColumnType::Int8 => read_typed(row, index, |value: i64| SqlValue::I64(value)),
58 ColumnType::Intn => read_intn(row, index),
59 ColumnType::Float4 => read_typed(row, index, |value: f32| SqlValue::F64(f64::from(value))),
60 ColumnType::Float8 | ColumnType::Floatn | ColumnType::Money | ColumnType::Money4 => {
61 read_typed(row, index, |value: f64| SqlValue::F64(value))
62 }
63 ColumnType::Guid => read_typed(row, index, |value: Uuid| SqlValue::Uuid(value)),
64 ColumnType::Decimaln | ColumnType::Numericn => {
65 read_typed(row, index, |value: Decimal| SqlValue::Decimal(value))
66 }
67 ColumnType::Daten => read_typed(row, index, |value: NaiveDate| SqlValue::Date(value)),
68 ColumnType::Datetime
69 | ColumnType::Datetime4
70 | ColumnType::Datetimen
71 | ColumnType::Datetime2 => {
72 read_typed(row, index, |value: NaiveDateTime| SqlValue::DateTime(value))
73 }
74 ColumnType::BigVarChar
75 | ColumnType::BigChar
76 | ColumnType::NVarchar
77 | ColumnType::NChar
78 | ColumnType::Text
79 | ColumnType::NText => read_string(row, index),
80 ColumnType::BigVarBin | ColumnType::BigBinary | ColumnType::Image => read_bytes(row, index),
81 ColumnType::Null
82 | ColumnType::Timen
83 | ColumnType::DatetimeOffsetn
84 | ColumnType::Xml
85 | ColumnType::Udt
86 | ColumnType::SSVariant => {
87 unreachable!("special-case column type should have returned early")
88 }
89 }
90}
91
92fn static_sql_value(column_type: ColumnType) -> Option<SqlValue> {
93 match column_type {
94 ColumnType::Null => Some(SqlValue::Null),
95 _ => None,
96 }
97}
98
99fn unsupported_column_type_error(column_type: ColumnType) -> Option<OrmError> {
100 match column_type {
101 ColumnType::Timen
102 | ColumnType::DatetimeOffsetn
103 | ColumnType::Xml
104 | ColumnType::Udt
105 | ColumnType::SSVariant => Some(OrmError::new(
106 "unsupported SQL Server column type in MssqlRow",
107 )),
108 _ => None,
109 }
110}
111
112fn read_typed<T>(
113 row: &Row,
114 index: usize,
115 map: impl FnOnce(T) -> SqlValue,
116) -> Result<SqlValue, OrmError>
117where
118 for<'a> T: tiberius::FromSql<'a>,
119{
120 let value = row
121 .try_get::<T, _>(index)
122 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
123
124 Ok(value.map(map).unwrap_or(SqlValue::Null))
125}
126
127fn read_string(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
128 let value = row
129 .try_get::<&str, _>(index)
130 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
131
132 Ok(value
133 .map(|value| SqlValue::String(value.to_owned()))
134 .unwrap_or(SqlValue::Null))
135}
136
137fn read_bytes(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
138 let value = row
139 .try_get::<&[u8], _>(index)
140 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?;
141
142 Ok(value
143 .map(|value| SqlValue::Bytes(value.to_vec()))
144 .unwrap_or(SqlValue::Null))
145}
146
147fn read_intn(row: &Row, index: usize) -> Result<SqlValue, OrmError> {
148 if let Some(value) = row
149 .try_get::<i64, _>(index)
150 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
151 {
152 return Ok(SqlValue::I64(value));
153 }
154
155 if let Some(value) = row
156 .try_get::<i32, _>(index)
157 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
158 {
159 return Ok(SqlValue::I32(value));
160 }
161
162 if let Some(value) = row
163 .try_get::<i16, _>(index)
164 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
165 {
166 return Ok(SqlValue::I32(i32::from(value)));
167 }
168
169 if let Some(value) = row
170 .try_get::<u8, _>(index)
171 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ReadRowValue))?
172 {
173 return Ok(SqlValue::I32(i32::from(value)));
174 }
175
176 Ok(SqlValue::Null)
177}
178
#[cfg(test)]
mod tests {
    use super::{static_sql_value, unsupported_column_type_error};
    use sql_orm_core::SqlValue;
    use tiberius::ColumnType;

    /// Column types that `MssqlRow` rejects up front with a fixed error message.
    const UNSUPPORTED: [ColumnType; 5] = [
        ColumnType::Timen,
        ColumnType::DatetimeOffsetn,
        ColumnType::Xml,
        ColumnType::Udt,
        ColumnType::SSVariant,
    ];

    /// Column types that `read_sql_value` decodes from the cell itself. Its `unreachable!`
    /// arm relies on neither helper claiming these.
    const READABLE: [ColumnType; 12] = [
        ColumnType::Bit,
        ColumnType::Int4,
        ColumnType::Intn,
        ColumnType::Floatn,
        ColumnType::Money,
        ColumnType::Guid,
        ColumnType::Decimaln,
        ColumnType::Daten,
        ColumnType::Datetime2,
        ColumnType::NVarchar,
        ColumnType::BigVarBin,
        ColumnType::Image,
    ];

    #[test]
    fn reports_unsupported_sql_server_column_types() {
        for column_type in UNSUPPORTED {
            let error = unsupported_column_type_error(column_type).unwrap();
            assert_eq!(
                error.message(),
                "unsupported SQL Server column type in MssqlRow"
            );
        }
    }

    #[test]
    fn does_not_special_case_readable_column_types() {
        for column_type in READABLE {
            assert!(unsupported_column_type_error(column_type).is_none());
            assert!(static_sql_value(column_type).is_none());
        }
    }

    #[test]
    fn unsupported_column_types_have_no_static_value() {
        for column_type in UNSUPPORTED {
            assert!(static_sql_value(column_type).is_none());
        }
    }

    #[test]
    fn treats_sql_null_columns_as_sql_value_null() {
        let value = static_sql_value(ColumnType::Null).unwrap();

        assert_eq!(value, SqlValue::Null);
    }

    #[test]
    fn null_column_type_is_not_reported_as_unsupported() {
        assert!(unsupported_column_type_error(ColumnType::Null).is_none());
    }
}