Skip to main content

sqlx_mssql_odbc_core/
any.rs

1//! `Any` driver integration for MSSQL via ODBC.
2//!
3//! Implements [`AnyConnectionBackend`] for [`MssqlConnection`] and the required
4//! type conversions so that `sqlx-cli` and `AnyConnection` can work with
5//! `mssql://` URLs.
6
7use crate::{
8    Mssql, MssqlArguments, MssqlColumn, MssqlConnectOptions, MssqlConnection, MssqlQueryResult,
9    MssqlRow, MssqlTransactionManager, MssqlTypeInfo,
10};
11use futures_core::future::BoxFuture;
12use futures_core::stream::BoxStream;
13use futures_util::{future, stream, FutureExt, StreamExt};
14use odbc_api::DataType;
15use sqlx_core::any::{
16    AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
17    AnyStatement, AnyTypeInfo, AnyTypeInfoKind,
18};
19use sqlx_core::column::Column;
20use sqlx_core::connection::Connection;
21use sqlx_core::database::Database;
22use sqlx_core::encode::{Encode, IsNull};
23use sqlx_core::error::BoxDynError;
24use sqlx_core::executor::Executor;
25use sqlx_core::ext::ustr::UStr;
26use sqlx_core::row::Row;
27use sqlx_core::sql_str::SqlStr;
28use sqlx_core::statement::Statement;
29use sqlx_core::transaction::TransactionManager;
30use sqlx_core::HashMap;
31use std::str::FromStr;
32use std::sync::Arc;
33
// Declares this crate's `DRIVER` descriptor for the `Mssql` database so it can be
// installed into the `Any` driver registry (used by `AnyConnection` / `sqlx-cli`).
// NOTE(review): judging by the macro name, migration support is wired in only when the
// `migrate` feature is enabled — confirm against the sqlx_core macro definition.
sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Mssql);
35
36// ---------------------------------------------------------------------------
37// Additional Encode impl needed by AnyArguments::convert_into
38//
39// The upstream `impl_encode_for_smartpointer!(Arc<T>)` macro generates
40// `Arc<T>: Encode<DB>` only when `T: Encode<DB>`.  Since `str: Encode<Mssql>`
41// is not implemented (only `&str` is), we must provide `Arc<str>: Encode` manually.
42// ---------------------------------------------------------------------------
43
44impl<'q> Encode<'q, Mssql> for Arc<str> {
45    fn encode(
46        self,
47        buf: &mut Vec<crate::MssqlArgumentValue>,
48    ) -> Result<IsNull, BoxDynError> {
49        buf.push(crate::MssqlArgumentValue::Text(self.to_string()));
50        Ok(IsNull::No)
51    }
52
53    fn encode_by_ref(
54        &self,
55        buf: &mut Vec<crate::MssqlArgumentValue>,
56    ) -> Result<IsNull, BoxDynError> {
57        buf.push(crate::MssqlArgumentValue::Text(self.to_string()));
58        Ok(IsNull::No)
59    }
60}
61
62// ---------------------------------------------------------------------------
63// AnyConnectionBackend
64// ---------------------------------------------------------------------------
65
66impl AnyConnectionBackend for MssqlConnection {
67    fn name(&self) -> &str {
68        <Mssql as Database>::NAME
69    }
70
71    fn close(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
72        Connection::close(*self).boxed()
73    }
74
75    fn close_hard(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
76        Connection::close_hard(*self).boxed()
77    }
78
79    fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
80        Connection::ping(self).boxed()
81    }
82
83    fn begin(&mut self, statement: Option<SqlStr>) -> BoxFuture<'_, sqlx_core::Result<()>> {
84        MssqlTransactionManager::begin(self, statement).boxed()
85    }
86
87    fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
88        MssqlTransactionManager::commit(self).boxed()
89    }
90
91    fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
92        MssqlTransactionManager::rollback(self).boxed()
93    }
94
95    fn start_rollback(&mut self) {
96        MssqlTransactionManager::start_rollback(self)
97    }
98
99    fn get_transaction_depth(&self) -> usize {
100        MssqlTransactionManager::get_transaction_depth(self)
101    }
102
103    fn shrink_buffers(&mut self) {
104        Connection::shrink_buffers(self);
105    }
106
107    fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
108        Connection::flush(self).boxed()
109    }
110
111    fn should_flush(&self) -> bool {
112        Connection::should_flush(self)
113    }
114
115    fn cached_statements_size(&self) -> usize {
116        Connection::cached_statements_size(self)
117    }
118
119    fn clear_cached_statements(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
120        Connection::clear_cached_statements(self).boxed()
121    }
122
123    #[cfg(feature = "migrate")]
124    fn as_migrate(
125        &mut self,
126    ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> {
127        Ok(self)
128    }
129
130    fn fetch_many(
131        &mut self,
132        query: SqlStr,
133        persistent: bool,
134        arguments: Option<AnyArguments>,
135    ) -> BoxStream<'_, sqlx_core::Result<sqlx_core::Either<AnyQueryResult, AnyRow>>> {
136        let persistent = persistent && arguments.is_some();
137
138        let arguments: Option<MssqlArguments> = match arguments
139            .map(|a| a.convert_into::<MssqlArguments>())
140            .transpose()
141        {
142            Ok(args) => args,
143            Err(error) => {
144                return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed()
145            }
146        };
147
148        let rx = self.execute_receiver(query, persistent, arguments);
149        receiver_to_any_stream(rx)
150    }
151
152    fn fetch_optional(
153        &mut self,
154        query: SqlStr,
155        persistent: bool,
156        arguments: Option<AnyArguments>,
157    ) -> BoxFuture<'_, sqlx_core::Result<Option<AnyRow>>> {
158        let persistent = persistent && arguments.is_some();
159
160        let arguments: Option<MssqlArguments> = match arguments
161            .map(|a| a.convert_into::<MssqlArguments>())
162            .transpose()
163        {
164            Ok(args) => args,
165            Err(error) => return Box::pin(future::ready(Err(sqlx_core::Error::Encode(error)))),
166        };
167
168        let rx = self.execute_receiver(query, persistent, arguments);
169        Box::pin(async move {
170            while let Ok(item) = rx.recv_async().await {
171                match item? {
172                    sqlx_core::Either::Right(row) => return Ok(Some(AnyRow::try_from(&row)?)),
173                    sqlx_core::Either::Left(_) => {}
174                }
175            }
176            Ok(None)
177        })
178    }
179
180    fn prepare_with<'c, 'q: 'c>(
181        &'c mut self,
182        sql: SqlStr,
183        _parameters: &[AnyTypeInfo],
184    ) -> BoxFuture<'c, sqlx_core::Result<AnyStatement>> {
185        Box::pin(async move {
186            let statement = Executor::prepare_with(self, sql, &[]).await?;
187            // Clone column names into owned Strings for UStr conversion
188            let columns: Vec<MssqlColumn> = statement.columns().to_vec();
189            let mut names = HashMap::<UStr, usize>::new();
190            for (i, col) in columns.iter().enumerate() {
191                names.insert(UStr::from(col.name().to_owned()), i);
192            }
193            let column_names = Arc::new(names);
194            AnyStatement::try_from_statement(statement, column_names)
195        })
196    }
197
198    #[cfg(feature = "offline")]
199    fn describe(
200        &mut self,
201        sql: SqlStr,
202    ) -> BoxFuture<
203        '_,
204        sqlx_core::Result<sqlx_core::describe::Describe<sqlx_core::any::Any>>,
205    > {
206        Box::pin(async move {
207            let describe = Executor::describe(self, sql).await?;
208            describe.try_into_any()
209        })
210    }
211}
212
213// ---------------------------------------------------------------------------
214// Type conversions
215// ---------------------------------------------------------------------------
216
217impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo {
218    type Error = sqlx_core::Error;
219
220    fn try_from(type_info: &'a MssqlTypeInfo) -> Result<Self, Self::Error> {
221        let kind = match type_info.data_type() {
222            DataType::Bit => AnyTypeInfoKind::Bool,
223            DataType::TinyInt | DataType::SmallInt => AnyTypeInfoKind::SmallInt,
224            DataType::Integer => AnyTypeInfoKind::Integer,
225            DataType::BigInt => AnyTypeInfoKind::BigInt,
226            DataType::Real => AnyTypeInfoKind::Real,
227            DataType::Float { .. } | DataType::Double => AnyTypeInfoKind::Double,
228            // Text types
229            DataType::Char { .. }
230            | DataType::Varchar { .. }
231            | DataType::LongVarchar { .. }
232            | DataType::WChar { .. }
233            | DataType::WVarchar { .. }
234            | DataType::WLongVarchar { .. } => AnyTypeInfoKind::Text,
235            // Binary types
236            DataType::Binary { .. }
237            | DataType::Varbinary { .. }
238            | DataType::LongVarbinary { .. } => AnyTypeInfoKind::Blob,
239            // Date/time types — no dedicated AnyTypeInfoKind, fall back to Text
240            DataType::Date | DataType::Time { .. } | DataType::Timestamp { .. } => {
241                AnyTypeInfoKind::Text
242            }
243            // Decimal / Numeric — no dedicated AnyTypeInfoKind, fall back to Text
244            DataType::Decimal { .. } | DataType::Numeric { .. } => AnyTypeInfoKind::Text,
245            // Other (GUID, Unknown, Null indicator, etc.) — fall back to Text
246            DataType::Other { .. } | DataType::Unknown => AnyTypeInfoKind::Text,
247        };
248
249        Ok(AnyTypeInfo { kind })
250    }
251}
252
253impl<'a> TryFrom<&'a MssqlColumn> for AnyColumn {
254    type Error = sqlx_core::Error;
255
256    fn try_from(column: &'a MssqlColumn) -> Result<Self, Self::Error> {
257        let type_info = AnyTypeInfo::try_from(column.type_info())?;
258
259        Ok(AnyColumn {
260            ordinal: column.ordinal(),
261            // Clone the &str to an owned String for UStr conversion
262            name: UStr::from(column.name().to_owned()),
263            type_info,
264        })
265    }
266}
267
268impl<'a> TryFrom<&'a MssqlRow> for AnyRow {
269    type Error = sqlx_core::Error;
270
271    fn try_from(row: &'a MssqlRow) -> Result<Self, Self::Error> {
272        // Clone column names into owned Strings for Arc<HashMap<UStr, usize>>
273        let columns: Vec<MssqlColumn> = row.columns().to_vec();
274        let mut names = HashMap::<UStr, usize>::new();
275        for (i, col) in columns.iter().enumerate() {
276            names.insert(UStr::from(col.name().to_owned()), i);
277        }
278        let column_names = Arc::new(names);
279        AnyRow::map_from(row, column_names)
280    }
281}
282
283impl<'a> TryFrom<&'a AnyConnectOptions> for MssqlConnectOptions {
284    type Error = sqlx_core::Error;
285
286    fn try_from(any_opts: &'a AnyConnectOptions) -> Result<Self, Self::Error> {
287        // Use FromStr to parse the database URL into MssqlConnectOptions
288        let mut opts: MssqlConnectOptions =
289            FromStr::from_str(any_opts.database_url.as_str())?;
290        opts.log_statements = any_opts.log_settings.statements_level;
291        opts.log_slow_statements = any_opts.log_settings.slow_statements_level;
292        opts.log_slow_statement_duration = any_opts.log_settings.slow_statements_duration;
293        Ok(opts)
294    }
295}
296
297// ---------------------------------------------------------------------------
298// Helper: convert an ExecuteResult stream to an AnyResult stream
299// ---------------------------------------------------------------------------
300
301fn receiver_to_any_stream(
302    rx: flume::Receiver<
303        sqlx_core::Result<sqlx_core::Either<MssqlQueryResult, MssqlRow>>,
304    >,
305) -> BoxStream<'static, sqlx_core::Result<sqlx_core::Either<AnyQueryResult, AnyRow>>> {
306    stream::unfold(rx, |rx| async move {
307        rx.recv_async().await.ok().map(|item| {
308            let mapped = match item {
309                Ok(sqlx_core::Either::Left(result)) => {
310                    Ok(sqlx_core::Either::Left(map_result(result)))
311                }
312                Ok(sqlx_core::Either::Right(row)) => {
313                    AnyRow::try_from(&row).map(sqlx_core::Either::Right)
314                }
315                Err(err) => Err(err),
316            };
317            (mapped, rx)
318        })
319    })
320    .boxed()
321}
322
323fn map_result(result: MssqlQueryResult) -> AnyQueryResult {
324    AnyQueryResult {
325        rows_affected: result.rows_affected(),
326        last_insert_id: None,
327    }
328}