Skip to main content

sqlx_sqlserver/
any.rs

1//! Runtime `Any` driver registration for SQL Server.
2//!
3//! The driver can be installed with SQLx `Any`. The current SQL Server port supports SQL batch
4//! execution and the same stable scalar RPC argument types as the native connection.
5
6use crate::{
7    Mssql, MssqlArguments, MssqlColumn, MssqlConnectOptions, MssqlConnection, MssqlQueryResult,
8    MssqlTransactionManager, MssqlType, MssqlTypeInfo,
9};
10use futures_core::future::BoxFuture;
11use futures_core::stream::BoxStream;
12use futures_util::{future, stream, FutureExt, StreamExt};
13use sqlx_core::any::driver::AnyDriver;
14use sqlx_core::any::{
15    AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
16    AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind,
17};
18use sqlx_core::arguments::Arguments;
19use sqlx_core::column::Column;
20use sqlx_core::connection::{ConnectOptions, Connection};
21use sqlx_core::database::Database;
22use sqlx_core::ext::ustr::UStr;
23use sqlx_core::row::Row;
24use sqlx_core::sql_str::SqlStr;
25use sqlx_core::statement::Statement;
26use sqlx_core::transaction::TransactionManager;
27use sqlx_core::{Either, Error, HashMap};
28use std::sync::Arc;
29
/// Installable SQL Server driver for SQLx `Any` connections.
///
/// Built with `AnyDriver::with_migrate` for the [`Mssql`] database type.
// NOTE(review): presumably this is what callers pass to the `Any` driver installation
// routine so that SQL Server URLs resolve to `MssqlConnection` — confirm against sqlx_core.
pub const DRIVER: AnyDriver = AnyDriver::with_migrate::<Mssql>();
32
33impl AnyConnectionBackend for MssqlConnection {
34    fn name(&self) -> &str {
35        <Mssql as Database>::NAME
36    }
37
38    fn close(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
39        Connection::close(*self).boxed()
40    }
41
42    fn close_hard(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
43        Connection::close_hard(*self).boxed()
44    }
45
46    fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
47        Connection::ping(self).boxed()
48    }
49
50    fn begin(&mut self, statement: Option<SqlStr>) -> BoxFuture<'_, sqlx_core::Result<()>> {
51        MssqlTransactionManager::begin(self, statement).boxed()
52    }
53
54    fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
55        MssqlTransactionManager::commit(self).boxed()
56    }
57
58    fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
59        MssqlTransactionManager::rollback(self).boxed()
60    }
61
62    fn start_rollback(&mut self) {
63        MssqlTransactionManager::start_rollback(self);
64    }
65
66    fn get_transaction_depth(&self) -> usize {
67        MssqlTransactionManager::get_transaction_depth(self)
68    }
69
70    fn shrink_buffers(&mut self) {
71        Connection::shrink_buffers(self);
72    }
73
74    fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
75        Connection::flush(self).boxed()
76    }
77
78    fn should_flush(&self) -> bool {
79        Connection::should_flush(self)
80    }
81
82    #[cfg(feature = "migrate")]
83    fn as_migrate(
84        &mut self,
85    ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> {
86        Ok(self)
87    }
88
89    fn fetch_many(
90        &mut self,
91        query: SqlStr,
92        _persistent: bool,
93        arguments: Option<AnyArguments>,
94    ) -> BoxStream<'_, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
95        stream::once(async move {
96            let native_arguments = arguments
97                .map(convert_any_arguments)
98                .transpose()
99                .map_err(Error::Encode)?;
100            self.run_execute_sql(query.as_str(), native_arguments.as_ref())
101                .await
102        })
103        .map(|result| match result {
104            Ok(output) => {
105                let column_names = column_names(&output.columns);
106                let rows = output.rows.into_iter().map(move |row| {
107                    AnyRow::map_from(&row, Arc::clone(&column_names)).map(Either::Right)
108                });
109                let done = std::iter::once(Ok(Either::Left(map_result(output.result))));
110                stream::iter(rows.chain(done)).boxed()
111            }
112            Err(error) => stream::once(future::ready(Err(error))).boxed(),
113        })
114        .flatten()
115        .boxed()
116    }
117
118    fn fetch_optional(
119        &mut self,
120        query: SqlStr,
121        _persistent: bool,
122        arguments: Option<AnyArguments>,
123    ) -> BoxFuture<'_, sqlx_core::Result<Option<AnyRow>>> {
124        Box::pin(async move {
125            let native_arguments = arguments
126                .map(convert_any_arguments)
127                .transpose()
128                .map_err(Error::Encode)?;
129            self.run_execute_sql(query.as_str(), native_arguments.as_ref())
130                .await?
131                .rows
132                .into_iter()
133                .next()
134                .map(|row| {
135                    let column_names = column_names(row.columns());
136                    AnyRow::map_from(&row, column_names)
137                })
138                .transpose()
139        })
140    }
141
142    fn prepare_with<'c, 'q: 'c>(
143        &'c mut self,
144        sql: SqlStr,
145        parameters: &[AnyTypeInfo],
146    ) -> BoxFuture<'c, sqlx_core::Result<AnyStatement>> {
147        let parameters = parameters
148            .iter()
149            .map(mssql_type_from_any)
150            .collect::<Result<Vec<_>, _>>();
151
152        Box::pin(async move {
153            let parameters = parameters?;
154            let statement = self.run_prepare(sql.as_str(), &parameters).await?;
155            let statement = crate::MssqlStatement::with_parameters(
156                sql,
157                statement.columns,
158                if parameters.is_empty() {
159                    None
160                } else {
161                    Some(Either::Left(parameters))
162                },
163            );
164            let column_names = column_names(statement.columns());
165            AnyStatement::try_from_statement(statement, column_names)
166        })
167    }
168}
169
170fn mssql_type_from_any(type_info: &AnyTypeInfo) -> Result<MssqlTypeInfo, Error> {
171    Ok(match type_info.kind() {
172        AnyTypeInfoKind::Bool => MssqlTypeInfo::BIT,
173        AnyTypeInfoKind::SmallInt => MssqlTypeInfo::SMALLINT,
174        AnyTypeInfoKind::Integer | AnyTypeInfoKind::Null => MssqlTypeInfo::INT,
175        AnyTypeInfoKind::BigInt => MssqlTypeInfo::BIGINT,
176        AnyTypeInfoKind::Real => MssqlTypeInfo::REAL,
177        AnyTypeInfoKind::Double => MssqlTypeInfo::FLOAT,
178        AnyTypeInfoKind::Text => MssqlTypeInfo::NVARCHAR,
179        AnyTypeInfoKind::Blob => MssqlTypeInfo::VARBINARY,
180    })
181}
182
183fn convert_any_arguments(
184    arguments: AnyArguments,
185) -> Result<MssqlArguments, sqlx_core::error::BoxDynError> {
186    let mut out = MssqlArguments::default();
187
188    for argument in arguments.values.0 {
189        match argument {
190            AnyValueKind::Null(AnyTypeInfoKind::Null) => out.add(Option::<i32>::None),
191            AnyValueKind::Null(AnyTypeInfoKind::Bool) => out.add(Option::<bool>::None),
192            AnyValueKind::Null(AnyTypeInfoKind::SmallInt) => out.add(Option::<i16>::None),
193            AnyValueKind::Null(AnyTypeInfoKind::Integer) => out.add(Option::<i32>::None),
194            AnyValueKind::Null(AnyTypeInfoKind::BigInt) => out.add(Option::<i64>::None),
195            AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::<f32>::None),
196            AnyValueKind::Null(AnyTypeInfoKind::Double) => out.add(Option::<f64>::None),
197            AnyValueKind::Null(AnyTypeInfoKind::Text) => out.add(Option::<String>::None),
198            AnyValueKind::Null(AnyTypeInfoKind::Blob) => out.add(Option::<Vec<u8>>::None),
199            AnyValueKind::Bool(value) => out.add(value),
200            AnyValueKind::SmallInt(value) => out.add(value),
201            AnyValueKind::Integer(value) => out.add(value),
202            AnyValueKind::BigInt(value) => out.add(value),
203            AnyValueKind::Real(value) => out.add(value),
204            AnyValueKind::Double(value) => out.add(value),
205            AnyValueKind::Text(value) => out.add(value.as_str()),
206            AnyValueKind::TextSlice(value) => out.add(value.as_ref()),
207            AnyValueKind::Blob(value) => out.add(value.as_slice()),
208            other => {
209                return Err(format!(
210                    "SQL Server Any driver does not support argument value {other:?}"
211                )
212                .into());
213            }
214        }?;
215    }
216
217    Ok(out)
218}
219
220fn map_result(result: MssqlQueryResult) -> AnyQueryResult {
221    AnyQueryResult {
222        rows_affected: result.rows_affected(),
223        last_insert_id: None,
224    }
225}
226
227fn column_names(columns: &[MssqlColumn]) -> Arc<HashMap<UStr, usize>> {
228    Arc::new(
229        columns
230            .iter()
231            .map(|column| (UStr::new(column.name()), column.ordinal()))
232            .collect(),
233    )
234}
235
236impl<'a> TryFrom<&'a AnyConnectOptions> for MssqlConnectOptions {
237    type Error = Error;
238
239    fn try_from(options: &'a AnyConnectOptions) -> Result<Self, Self::Error> {
240        MssqlConnectOptions::from_url(&options.database_url)
241    }
242}
243
244impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo {
245    type Error = Error;
246
247    fn try_from(type_info: &'a MssqlTypeInfo) -> Result<Self, Self::Error> {
248        let kind = match type_info.kind() {
249            MssqlType::Null => AnyTypeInfoKind::Null,
250            MssqlType::Bit => AnyTypeInfoKind::Bool,
251            MssqlType::SmallInt => AnyTypeInfoKind::SmallInt,
252            MssqlType::Int => AnyTypeInfoKind::Integer,
253            MssqlType::BigInt => AnyTypeInfoKind::BigInt,
254            MssqlType::Real => AnyTypeInfoKind::Real,
255            MssqlType::Float => AnyTypeInfoKind::Double,
256            MssqlType::NVarChar | MssqlType::VarChar => AnyTypeInfoKind::Text,
257            MssqlType::VarBinary => AnyTypeInfoKind::Blob,
258            MssqlType::TinyInt
259            | MssqlType::Decimal
260            | MssqlType::Money
261            | MssqlType::Date
262            | MssqlType::Time
263            | MssqlType::DateTime
264            | MssqlType::DateTime2
265            | MssqlType::DateTimeOffset
266            | MssqlType::UniqueIdentifier
267            | MssqlType::Other(_) => {
268                return Err(Error::AnyDriverError(
269                    format!("Any driver does not support the SQL Server type {type_info:?}").into(),
270                ));
271            }
272        };
273
274        Ok(AnyTypeInfo { kind })
275    }
276}
277
278impl<'a> TryFrom<&'a MssqlColumn> for AnyColumn {
279    type Error = Error;
280
281    fn try_from(column: &'a MssqlColumn) -> Result<Self, Self::Error> {
282        let type_info =
283            AnyTypeInfo::try_from(column.type_info()).map_err(|error| Error::ColumnDecode {
284                index: column.name().to_owned(),
285                source: error.into(),
286            })?;
287
288        Ok(Self {
289            ordinal: column.ordinal(),
290            name: UStr::new(column.name()),
291            type_info,
292        })
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Each supported native type converts to the expected `Any` kind.
    #[test]
    fn maps_stable_sql_server_types_to_any_types() {
        let expectations = [
            (MssqlTypeInfo::BIT, AnyTypeInfoKind::Bool),
            (MssqlTypeInfo::INT, AnyTypeInfoKind::Integer),
            (MssqlTypeInfo::NVARCHAR, AnyTypeInfoKind::Text),
            (MssqlTypeInfo::VARCHAR, AnyTypeInfoKind::Text),
            (MssqlTypeInfo::VARBINARY, AnyTypeInfoKind::Blob),
        ];

        for (native, expected) in expectations {
            let converted = AnyTypeInfo::try_from(&native).unwrap();
            assert_eq!(converted.kind(), expected);
        }
    }

    /// A native type without an `Any` counterpart is rejected with `AnyDriverError`.
    #[test]
    fn rejects_unstable_sql_server_types_for_any_mapping() {
        let outcome = AnyTypeInfo::try_from(&MssqlTypeInfo::TINYINT);

        assert!(matches!(outcome, Err(Error::AnyDriverError(_))));
    }
}