1use crate::{
7 Mssql, MssqlArguments, MssqlColumn, MssqlConnectOptions, MssqlConnection, MssqlQueryResult,
8 MssqlTransactionManager, MssqlType, MssqlTypeInfo,
9};
10use futures_core::future::BoxFuture;
11use futures_core::stream::BoxStream;
12use futures_util::{future, stream, FutureExt, StreamExt};
13use sqlx_core::any::driver::AnyDriver;
14use sqlx_core::any::{
15 AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
16 AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind,
17};
18use sqlx_core::arguments::Arguments;
19use sqlx_core::column::Column;
20use sqlx_core::connection::{ConnectOptions, Connection};
21use sqlx_core::database::Database;
22use sqlx_core::ext::ustr::UStr;
23use sqlx_core::row::Row;
24use sqlx_core::sql_str::SqlStr;
25use sqlx_core::statement::Statement;
26use sqlx_core::transaction::TransactionManager;
27use sqlx_core::{Either, Error, HashMap};
28use std::sync::Arc;
29
/// `Any`-driver registration entry for the SQL Server (`Mssql`) backend.
///
/// Constructed via `AnyDriver::with_migrate::<Mssql>()`; presumably this also wires up
/// migration support for `Mssql` when the `migrate` feature is enabled — confirm against
/// `sqlx_core::any::driver::AnyDriver`.
pub const DRIVER: AnyDriver = AnyDriver::with_migrate::<Mssql>();
32
33impl AnyConnectionBackend for MssqlConnection {
34 fn name(&self) -> &str {
35 <Mssql as Database>::NAME
36 }
37
38 fn close(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
39 Connection::close(*self).boxed()
40 }
41
42 fn close_hard(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
43 Connection::close_hard(*self).boxed()
44 }
45
46 fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
47 Connection::ping(self).boxed()
48 }
49
50 fn begin(&mut self, statement: Option<SqlStr>) -> BoxFuture<'_, sqlx_core::Result<()>> {
51 MssqlTransactionManager::begin(self, statement).boxed()
52 }
53
54 fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
55 MssqlTransactionManager::commit(self).boxed()
56 }
57
58 fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
59 MssqlTransactionManager::rollback(self).boxed()
60 }
61
62 fn start_rollback(&mut self) {
63 MssqlTransactionManager::start_rollback(self);
64 }
65
66 fn get_transaction_depth(&self) -> usize {
67 MssqlTransactionManager::get_transaction_depth(self)
68 }
69
70 fn shrink_buffers(&mut self) {
71 Connection::shrink_buffers(self);
72 }
73
74 fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
75 Connection::flush(self).boxed()
76 }
77
78 fn should_flush(&self) -> bool {
79 Connection::should_flush(self)
80 }
81
82 #[cfg(feature = "migrate")]
83 fn as_migrate(
84 &mut self,
85 ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> {
86 Ok(self)
87 }
88
89 fn fetch_many(
90 &mut self,
91 query: SqlStr,
92 _persistent: bool,
93 arguments: Option<AnyArguments>,
94 ) -> BoxStream<'_, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
95 stream::once(async move {
96 let native_arguments = arguments
97 .map(convert_any_arguments)
98 .transpose()
99 .map_err(Error::Encode)?;
100 self.run_execute_sql(query.as_str(), native_arguments.as_ref())
101 .await
102 })
103 .map(|result| match result {
104 Ok(output) => {
105 let column_names = column_names(&output.columns);
106 let rows = output.rows.into_iter().map(move |row| {
107 AnyRow::map_from(&row, Arc::clone(&column_names)).map(Either::Right)
108 });
109 let done = std::iter::once(Ok(Either::Left(map_result(output.result))));
110 stream::iter(rows.chain(done)).boxed()
111 }
112 Err(error) => stream::once(future::ready(Err(error))).boxed(),
113 })
114 .flatten()
115 .boxed()
116 }
117
118 fn fetch_optional(
119 &mut self,
120 query: SqlStr,
121 _persistent: bool,
122 arguments: Option<AnyArguments>,
123 ) -> BoxFuture<'_, sqlx_core::Result<Option<AnyRow>>> {
124 Box::pin(async move {
125 let native_arguments = arguments
126 .map(convert_any_arguments)
127 .transpose()
128 .map_err(Error::Encode)?;
129 self.run_execute_sql(query.as_str(), native_arguments.as_ref())
130 .await?
131 .rows
132 .into_iter()
133 .next()
134 .map(|row| {
135 let column_names = column_names(row.columns());
136 AnyRow::map_from(&row, column_names)
137 })
138 .transpose()
139 })
140 }
141
142 fn prepare_with<'c, 'q: 'c>(
143 &'c mut self,
144 sql: SqlStr,
145 parameters: &[AnyTypeInfo],
146 ) -> BoxFuture<'c, sqlx_core::Result<AnyStatement>> {
147 let parameters = parameters
148 .iter()
149 .map(mssql_type_from_any)
150 .collect::<Result<Vec<_>, _>>();
151
152 Box::pin(async move {
153 let parameters = parameters?;
154 let statement = self.run_prepare(sql.as_str(), ¶meters).await?;
155 let statement = crate::MssqlStatement::with_parameters(
156 sql,
157 statement.columns,
158 if parameters.is_empty() {
159 None
160 } else {
161 Some(Either::Left(parameters))
162 },
163 );
164 let column_names = column_names(statement.columns());
165 AnyStatement::try_from_statement(statement, column_names)
166 })
167 }
168}
169
170fn mssql_type_from_any(type_info: &AnyTypeInfo) -> Result<MssqlTypeInfo, Error> {
171 Ok(match type_info.kind() {
172 AnyTypeInfoKind::Bool => MssqlTypeInfo::BIT,
173 AnyTypeInfoKind::SmallInt => MssqlTypeInfo::SMALLINT,
174 AnyTypeInfoKind::Integer | AnyTypeInfoKind::Null => MssqlTypeInfo::INT,
175 AnyTypeInfoKind::BigInt => MssqlTypeInfo::BIGINT,
176 AnyTypeInfoKind::Real => MssqlTypeInfo::REAL,
177 AnyTypeInfoKind::Double => MssqlTypeInfo::FLOAT,
178 AnyTypeInfoKind::Text => MssqlTypeInfo::NVARCHAR,
179 AnyTypeInfoKind::Blob => MssqlTypeInfo::VARBINARY,
180 })
181}
182
183fn convert_any_arguments(
184 arguments: AnyArguments,
185) -> Result<MssqlArguments, sqlx_core::error::BoxDynError> {
186 let mut out = MssqlArguments::default();
187
188 for argument in arguments.values.0 {
189 match argument {
190 AnyValueKind::Null(AnyTypeInfoKind::Null) => out.add(Option::<i32>::None),
191 AnyValueKind::Null(AnyTypeInfoKind::Bool) => out.add(Option::<bool>::None),
192 AnyValueKind::Null(AnyTypeInfoKind::SmallInt) => out.add(Option::<i16>::None),
193 AnyValueKind::Null(AnyTypeInfoKind::Integer) => out.add(Option::<i32>::None),
194 AnyValueKind::Null(AnyTypeInfoKind::BigInt) => out.add(Option::<i64>::None),
195 AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::<f32>::None),
196 AnyValueKind::Null(AnyTypeInfoKind::Double) => out.add(Option::<f64>::None),
197 AnyValueKind::Null(AnyTypeInfoKind::Text) => out.add(Option::<String>::None),
198 AnyValueKind::Null(AnyTypeInfoKind::Blob) => out.add(Option::<Vec<u8>>::None),
199 AnyValueKind::Bool(value) => out.add(value),
200 AnyValueKind::SmallInt(value) => out.add(value),
201 AnyValueKind::Integer(value) => out.add(value),
202 AnyValueKind::BigInt(value) => out.add(value),
203 AnyValueKind::Real(value) => out.add(value),
204 AnyValueKind::Double(value) => out.add(value),
205 AnyValueKind::Text(value) => out.add(value.as_str()),
206 AnyValueKind::TextSlice(value) => out.add(value.as_ref()),
207 AnyValueKind::Blob(value) => out.add(value.as_slice()),
208 other => {
209 return Err(format!(
210 "SQL Server Any driver does not support argument value {other:?}"
211 )
212 .into());
213 }
214 }?;
215 }
216
217 Ok(out)
218}
219
220fn map_result(result: MssqlQueryResult) -> AnyQueryResult {
221 AnyQueryResult {
222 rows_affected: result.rows_affected(),
223 last_insert_id: None,
224 }
225}
226
227fn column_names(columns: &[MssqlColumn]) -> Arc<HashMap<UStr, usize>> {
228 Arc::new(
229 columns
230 .iter()
231 .map(|column| (UStr::new(column.name()), column.ordinal()))
232 .collect(),
233 )
234}
235
236impl<'a> TryFrom<&'a AnyConnectOptions> for MssqlConnectOptions {
237 type Error = Error;
238
239 fn try_from(options: &'a AnyConnectOptions) -> Result<Self, Self::Error> {
240 MssqlConnectOptions::from_url(&options.database_url)
241 }
242}
243
244impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo {
245 type Error = Error;
246
247 fn try_from(type_info: &'a MssqlTypeInfo) -> Result<Self, Self::Error> {
248 let kind = match type_info.kind() {
249 MssqlType::Null => AnyTypeInfoKind::Null,
250 MssqlType::Bit => AnyTypeInfoKind::Bool,
251 MssqlType::SmallInt => AnyTypeInfoKind::SmallInt,
252 MssqlType::Int => AnyTypeInfoKind::Integer,
253 MssqlType::BigInt => AnyTypeInfoKind::BigInt,
254 MssqlType::Real => AnyTypeInfoKind::Real,
255 MssqlType::Float => AnyTypeInfoKind::Double,
256 MssqlType::NVarChar | MssqlType::VarChar => AnyTypeInfoKind::Text,
257 MssqlType::VarBinary => AnyTypeInfoKind::Blob,
258 MssqlType::TinyInt | MssqlType::Other(_) => {
259 return Err(Error::AnyDriverError(
260 format!("Any driver does not support the SQL Server type {type_info:?}").into(),
261 ));
262 }
263 };
264
265 Ok(AnyTypeInfo { kind })
266 }
267}
268
269impl<'a> TryFrom<&'a MssqlColumn> for AnyColumn {
270 type Error = Error;
271
272 fn try_from(column: &'a MssqlColumn) -> Result<Self, Self::Error> {
273 let type_info =
274 AnyTypeInfo::try_from(column.type_info()).map_err(|error| Error::ColumnDecode {
275 index: column.name().to_owned(),
276 source: error.into(),
277 })?;
278
279 Ok(Self {
280 ordinal: column.ordinal(),
281 name: UStr::new(column.name()),
282 type_info,
283 })
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Each SQL Server type in the table must map to the listed `Any` kind.
    #[test]
    fn maps_stable_sql_server_types_to_any_types() {
        let expectations = [
            (MssqlTypeInfo::BIT, AnyTypeInfoKind::Bool),
            (MssqlTypeInfo::INT, AnyTypeInfoKind::Integer),
            (MssqlTypeInfo::NVARCHAR, AnyTypeInfoKind::Text),
            (MssqlTypeInfo::VARCHAR, AnyTypeInfoKind::Text),
            (MssqlTypeInfo::VARBINARY, AnyTypeInfoKind::Blob),
        ];

        for (sql_server_type, expected_kind) in &expectations {
            let any_type = AnyTypeInfo::try_from(sql_server_type).unwrap();
            assert_eq!(
                any_type.kind(),
                *expected_kind,
                "unexpected mapping for {sql_server_type:?}"
            );
        }
    }

    /// `TINYINT` has no `Any` mapping and must be reported as an Any-driver error.
    #[test]
    fn rejects_unstable_sql_server_types_for_any_mapping() {
        let outcome = AnyTypeInfo::try_from(&MssqlTypeInfo::TINYINT);
        assert!(matches!(outcome, Err(Error::AnyDriverError(_))));
    }
}