1use crate::{
4 connection::OdbcExecution, DataTypeExt, Odbc, OdbcArgumentValue, OdbcArguments, OdbcColumn,
5 OdbcConnectOptions, OdbcConnection, OdbcQueryResult, OdbcTransactionManager, OdbcTypeInfo,
6};
7use futures_core::future::BoxFuture;
8use futures_core::stream::BoxStream;
9use futures_util::{future, stream, FutureExt, StreamExt};
10use sqlx_core::any::driver::AnyDriver;
11use sqlx_core::any::{
12 AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
13 AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind,
14};
15use sqlx_core::column::Column;
16use sqlx_core::connection::{ConnectOptions, Connection};
17use sqlx_core::database::Database;
18use sqlx_core::ext::ustr::UStr;
19use sqlx_core::row::Row;
20use sqlx_core::sql_str::SqlStr;
21use sqlx_core::statement::Statement;
22use sqlx_core::transaction::TransactionManager;
23use sqlx_core::{Either, HashMap};
24use std::sync::Arc;
25
/// `Any` driver registration for the ODBC backend.
///
/// Built with `AnyDriver::without_migrate`, i.e. no migration support is
/// registered for this driver.
pub const DRIVER: AnyDriver = AnyDriver::without_migrate::<Odbc>();
28
29impl AnyConnectionBackend for OdbcConnection {
30 fn name(&self) -> &str {
31 <Odbc as Database>::NAME
32 }
33
34 fn close(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
35 Connection::close(*self).boxed()
36 }
37
38 fn close_hard(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
39 Connection::close_hard(*self).boxed()
40 }
41
42 fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
43 Connection::ping(self).boxed()
44 }
45
46 fn begin(&mut self, statement: Option<SqlStr>) -> BoxFuture<'_, sqlx_core::Result<()>> {
47 OdbcTransactionManager::begin(self, statement).boxed()
48 }
49
50 fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
51 OdbcTransactionManager::commit(self).boxed()
52 }
53
54 fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
55 OdbcTransactionManager::rollback(self).boxed()
56 }
57
58 fn start_rollback(&mut self) {
59 OdbcTransactionManager::start_rollback(self);
60 }
61
62 fn get_transaction_depth(&self) -> usize {
63 OdbcTransactionManager::get_transaction_depth(self)
64 }
65
66 fn shrink_buffers(&mut self) {
67 Connection::shrink_buffers(self);
68 }
69
70 fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
71 Connection::flush(self).boxed()
72 }
73
74 fn should_flush(&self) -> bool {
75 Connection::should_flush(self)
76 }
77
78 fn fetch_many(
79 &mut self,
80 query: SqlStr,
81 _persistent: bool,
82 arguments: Option<AnyArguments>,
83 ) -> BoxStream<'_, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
84 let arguments = arguments.map(map_arguments).transpose();
85
86 stream::once(async move {
87 let arguments = arguments?;
88 self.run_blocking_sql(query.as_str(), arguments.as_ref())
89 })
90 .map(|result| match result {
91 Ok(OdbcExecution::Done(result)) => {
92 stream::once(future::ready(Ok(Either::Left(map_result(result))))).boxed()
93 }
94 Ok(OdbcExecution::Rows(rows)) => {
95 if rows.is_empty() {
96 stream::once(future::ready(Ok(Either::Left(map_result(
97 OdbcQueryResult::new(0),
98 )))))
99 .boxed()
100 } else {
101 let column_names =
102 column_names(rows.first().expect("rows is not empty").columns());
103 let rows = rows.into_iter().map(move |row| {
104 AnyRow::map_from(&row, Arc::clone(&column_names)).map(Either::Right)
105 });
106 let done =
107 std::iter::once(Ok(Either::Left(map_result(OdbcQueryResult::new(0)))));
108 stream::iter(rows.chain(done)).boxed()
109 }
110 }
111 Err(error) => stream::once(future::ready(Err(error))).boxed(),
112 })
113 .flatten()
114 .boxed()
115 }
116
117 fn fetch_optional(
118 &mut self,
119 query: SqlStr,
120 _persistent: bool,
121 arguments: Option<AnyArguments>,
122 ) -> BoxFuture<'_, sqlx_core::Result<Option<AnyRow>>> {
123 let arguments = arguments.map(map_arguments).transpose();
124
125 Box::pin(async move {
126 let arguments = arguments?;
127 match self.run_blocking_sql(query.as_str(), arguments.as_ref())? {
128 OdbcExecution::Done(_) => Ok(None),
129 OdbcExecution::Rows(rows) => rows
130 .into_iter()
131 .next()
132 .map(|row| {
133 let column_names = column_names(row.columns());
134 AnyRow::map_from(&row, column_names)
135 })
136 .transpose(),
137 }
138 })
139 }
140
141 fn prepare_with<'c, 'q: 'c>(
142 &'c mut self,
143 sql: SqlStr,
144 _parameters: &[AnyTypeInfo],
145 ) -> BoxFuture<'c, sqlx_core::Result<AnyStatement>> {
146 Box::pin(async move {
147 let statement = self.prepare_blocking(sql)?;
148 let column_names = column_names(statement.columns());
149 AnyStatement::try_from_statement(statement, column_names)
150 })
151 }
152}
153
154impl<'a> TryFrom<&'a AnyConnectOptions> for OdbcConnectOptions {
155 type Error = sqlx_core::Error;
156
157 fn try_from(options: &'a AnyConnectOptions) -> Result<Self, Self::Error> {
158 let mut options_out = OdbcConnectOptions::from_url(&options.database_url)?;
159 options_out.log_statements = options.log_settings.statements_level;
160 options_out.log_slow_statements = options.log_settings.slow_statements_level;
161 options_out.log_slow_statement_duration = options.log_settings.slow_statements_duration;
162 Ok(options_out)
163 }
164}
165
166impl<'a> TryFrom<&'a OdbcTypeInfo> for AnyTypeInfo {
167 type Error = sqlx_core::Error;
168
169 fn try_from(type_info: &'a OdbcTypeInfo) -> Result<Self, Self::Error> {
170 let kind = match type_info.data_type() {
171 odbc_api::DataType::Unknown => AnyTypeInfoKind::Null,
172 odbc_api::DataType::Bit => AnyTypeInfoKind::Bool,
173 odbc_api::DataType::TinyInt | odbc_api::DataType::SmallInt => AnyTypeInfoKind::SmallInt,
174 odbc_api::DataType::Integer => AnyTypeInfoKind::Integer,
175 odbc_api::DataType::BigInt => AnyTypeInfoKind::BigInt,
176 odbc_api::DataType::Real => AnyTypeInfoKind::Real,
177 odbc_api::DataType::Float { .. } | odbc_api::DataType::Double => {
178 AnyTypeInfoKind::Double
179 }
180 data_type if data_type.accepts_character_data() => AnyTypeInfoKind::Text,
181 data_type if data_type.accepts_binary_data() => AnyTypeInfoKind::Blob,
182 data_type => {
183 return Err(sqlx_core::Error::AnyDriverError(
184 format!(
185 "ODBC Any conversion does not support result column type {data_type:?}"
186 )
187 .into(),
188 ));
189 }
190 };
191
192 Ok(AnyTypeInfo { kind })
193 }
194}
195
196impl<'a> TryFrom<&'a OdbcColumn> for AnyColumn {
197 type Error = sqlx_core::Error;
198
199 fn try_from(column: &'a OdbcColumn) -> Result<Self, Self::Error> {
200 let type_info = AnyTypeInfo::try_from(column.type_info()).map_err(|error| {
201 sqlx_core::Error::ColumnDecode {
202 index: column.name().to_owned(),
203 source: error.into(),
204 }
205 })?;
206
207 Ok(Self {
208 ordinal: column.ordinal(),
209 name: UStr::new(column.name()),
210 type_info,
211 })
212 }
213}
214
215fn map_arguments(arguments: AnyArguments) -> sqlx_core::Result<OdbcArguments> {
216 let mut out = OdbcArguments::default();
217
218 for value in arguments.values.0 {
219 out.add_value(match value {
220 AnyValueKind::Null(kind) => OdbcArgumentValue::Null(any_type_to_odbc(kind)),
221 AnyValueKind::Bool(value) => OdbcArgumentValue::Bit(value),
222 AnyValueKind::SmallInt(value) => OdbcArgumentValue::Int(i64::from(value)),
223 AnyValueKind::Integer(value) => OdbcArgumentValue::Int(i64::from(value)),
224 AnyValueKind::BigInt(value) => OdbcArgumentValue::Int(value),
225 AnyValueKind::Real(value) => OdbcArgumentValue::Float(f64::from(value)),
226 AnyValueKind::Double(value) => OdbcArgumentValue::Float(value),
227 AnyValueKind::Text(value) => OdbcArgumentValue::Text(value.to_string()),
228 AnyValueKind::TextSlice(value) => OdbcArgumentValue::Text(value.to_string()),
229 AnyValueKind::Blob(value) => OdbcArgumentValue::Bytes(value.to_vec()),
230 other => {
231 return Err(sqlx_core::Error::AnyDriverError(
232 format!("ODBC Any arguments do not support value kind {other:?}").into(),
233 ))
234 }
235 });
236 }
237
238 Ok(out)
239}
240
241fn any_type_to_odbc(kind: AnyTypeInfoKind) -> OdbcTypeInfo {
242 OdbcTypeInfo::new(match kind {
243 AnyTypeInfoKind::Null => odbc_api::DataType::Unknown,
244 AnyTypeInfoKind::Bool => odbc_api::DataType::Bit,
245 AnyTypeInfoKind::SmallInt => odbc_api::DataType::SmallInt,
246 AnyTypeInfoKind::Integer => odbc_api::DataType::Integer,
247 AnyTypeInfoKind::BigInt => odbc_api::DataType::BigInt,
248 AnyTypeInfoKind::Real => odbc_api::DataType::Real,
249 AnyTypeInfoKind::Double => odbc_api::DataType::Double,
250 AnyTypeInfoKind::Text => odbc_api::DataType::WVarchar { length: None },
251 AnyTypeInfoKind::Blob => odbc_api::DataType::Varbinary { length: None },
252 })
253}
254
255fn map_result(result: OdbcQueryResult) -> AnyQueryResult {
256 AnyQueryResult {
257 rows_affected: result.rows_affected(),
258 last_insert_id: None,
259 }
260}
261
262fn column_names(columns: &[OdbcColumn]) -> Arc<HashMap<UStr, usize>> {
263 Arc::new(
264 columns
265 .iter()
266 .map(|column| (UStr::new(column.name()), column.ordinal()))
267 .collect(),
268 )
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn maps_stable_odbc_types_to_any_types() {
        assert_eq!(
            AnyTypeInfo::try_from(&OdbcTypeInfo::new(odbc_api::DataType::Bit))
                .unwrap()
                .kind(),
            AnyTypeInfoKind::Bool
        );
        assert_eq!(
            AnyTypeInfo::try_from(&OdbcTypeInfo::new(odbc_api::DataType::Integer))
                .unwrap()
                .kind(),
            AnyTypeInfoKind::Integer
        );
        assert_eq!(
            AnyTypeInfo::try_from(&OdbcTypeInfo::new(odbc_api::DataType::WVarchar {
                length: None
            }))
            .unwrap()
            .kind(),
            AnyTypeInfoKind::Text
        );
        assert_eq!(
            AnyTypeInfo::try_from(&OdbcTypeInfo::new(odbc_api::DataType::Varbinary {
                length: None
            }))
            .unwrap()
            .kind(),
            AnyTypeInfoKind::Blob
        );
    }

    /// Covers the remaining explicit numeric arms, including the
    /// `TINYINT` -> `SmallInt` widening and `Unknown` -> `Null`.
    #[test]
    fn maps_numeric_and_unknown_odbc_types_to_any_types() {
        let cases = [
            (odbc_api::DataType::Unknown, AnyTypeInfoKind::Null),
            (odbc_api::DataType::TinyInt, AnyTypeInfoKind::SmallInt),
            (odbc_api::DataType::SmallInt, AnyTypeInfoKind::SmallInt),
            (odbc_api::DataType::BigInt, AnyTypeInfoKind::BigInt),
            (odbc_api::DataType::Real, AnyTypeInfoKind::Real),
            (odbc_api::DataType::Double, AnyTypeInfoKind::Double),
        ];

        for (data_type, expected) in cases {
            // Format before moving `data_type` into `OdbcTypeInfo::new`.
            let label = format!("{data_type:?}");
            assert_eq!(
                AnyTypeInfo::try_from(&OdbcTypeInfo::new(data_type))
                    .unwrap()
                    .kind(),
                expected,
                "{label}"
            );
        }
    }

    /// Every `Any` kind sent as a typed NULL must map back to the same kind.
    #[test]
    fn any_types_round_trip_through_odbc_types() {
        let kinds = [
            AnyTypeInfoKind::Null,
            AnyTypeInfoKind::Bool,
            AnyTypeInfoKind::SmallInt,
            AnyTypeInfoKind::Integer,
            AnyTypeInfoKind::BigInt,
            AnyTypeInfoKind::Real,
            AnyTypeInfoKind::Double,
            AnyTypeInfoKind::Text,
            AnyTypeInfoKind::Blob,
        ];

        for kind in kinds {
            let odbc_type = any_type_to_odbc(kind);
            assert_eq!(
                AnyTypeInfo::try_from(&odbc_type).unwrap().kind(),
                kind,
                "{kind:?}"
            );
        }
    }

    #[test]
    fn rejects_unstable_odbc_types_for_any_mapping() {
        assert!(matches!(
            AnyTypeInfo::try_from(&OdbcTypeInfo::new(odbc_api::DataType::Timestamp {
                precision: 6
            })),
            Err(sqlx_core::Error::AnyDriverError(_))
        ));
    }
}