sqlx_core_oldapi/odbc/connection/
mod.rs1use crate::common::StatementCache;
2use crate::connection::{Connection, LogSettings};
3use crate::error::Error;
4use crate::odbc::{
5 Odbc, OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcQueryResult,
6 OdbcRow, OdbcTypeInfo,
7};
8use crate::transaction::Transaction;
9use either::Either;
10use sqlx_rt::spawn_blocking;
11mod odbc_bridge;
12use crate::odbc::{OdbcStatement, OdbcStatementMetadata};
13use futures_core::future::BoxFuture;
14use futures_util::future;
15use odbc_api::{
16 handles::StatementConnection, ConnectionTransitions, Prepared, ResultSetMetadata,
17 SharedConnection,
18};
19use odbc_bridge::{establish_connection, execute_sql};
20use std::borrow::Cow;
21use std::sync::{Arc, Mutex};
22
23mod executor;
24
25type PreparedStatement = Prepared<StatementConnection<SharedConnection<'static>>>;
26type SharedPreparedStatement = Arc<Mutex<PreparedStatement>>;
27
28struct CollectedColumns {
29 columns: Vec<OdbcColumn>,
30 deferred: bool,
31}
32
33fn collect_columns(
34 prepared: &mut PreparedStatement,
35 parameter_count: usize,
36 allow_deferred_result_columns: bool,
37) -> Result<CollectedColumns, Error> {
38 let count = match prepared.num_result_cols() {
39 Ok(count) => count,
40 Err(error) if allow_deferred_result_columns && parameter_count > 0 => {
41 log::debug!("ODBC prepare deferred result columns until execution: {error}");
42 validate_parameter_metadata(prepared, parameter_count)?;
43 return Ok(CollectedColumns {
44 columns: Vec::new(),
45 deferred: true,
46 });
47 }
48 Err(error) => return Err(error.into()),
49 };
50
51 let mut columns = Vec::with_capacity(count as usize);
52 for i in 1..=count {
53 columns.push(describe_column(prepared, i as u16)?);
54 }
55 Ok(CollectedColumns {
56 columns,
57 deferred: false,
58 })
59}
60
61fn validate_parameter_metadata(
62 prepared: &mut PreparedStatement,
63 parameter_count: usize,
64) -> Result<(), Error> {
65 for index in 1..=parameter_count {
66 let parameter_number = u16::try_from(index)
67 .map_err(|_| Error::Protocol(format!("ODBC parameter index {index} exceeds u16")))?;
68 prepared.describe_param(parameter_number)?;
69 }
70 Ok(())
71}
72
73fn collect_statement_metadata(
74 prepared: &mut PreparedStatement,
75 allow_deferred_result_columns: bool,
76) -> Result<(OdbcStatementMetadata, bool), Error> {
77 let parameters = usize::from(prepared.num_params()?);
78 let collected = collect_columns(prepared, parameters, allow_deferred_result_columns)?;
79 let metadata_complete = !(collected.deferred || parameters > 0 && collected.columns.is_empty());
80
81 Ok((
82 OdbcStatementMetadata {
83 columns: collected.columns,
84 parameters,
85 },
86 metadata_complete,
87 ))
88}
89
90pub(super) fn describe_column<S>(stmt: &mut S, index: u16) -> Result<OdbcColumn, Error>
91where
92 S: ResultSetMetadata,
93{
94 let mut cd = odbc_api::ColumnDescription::default();
95 stmt.describe_col(index, &mut cd)?;
96
97 Ok(OdbcColumn {
98 name: decode_column_name(cd.name, index),
99 type_info: OdbcTypeInfo::new(cd.data_type),
100 ordinal: usize::from(
101 index
102 .checked_sub(1)
103 .ok_or_else(|| Error::Protocol("ODBC column indices are 1-based".into()))?,
104 ),
105 })
106}
107
108pub(super) trait ColumnNameDecode {
109 fn decode_or_default(self, index: u16) -> String;
110}
111
112impl ColumnNameDecode for Vec<u8> {
113 fn decode_or_default(self, index: u16) -> String {
114 String::from_utf8(self).unwrap_or_else(|_| format!("col{}", index - 1))
115 }
116}
117
118impl ColumnNameDecode for Vec<u16> {
119 fn decode_or_default(self, index: u16) -> String {
120 String::from_utf16(&self).unwrap_or_else(|_| format!("col{}", index - 1))
121 }
122}
123
124pub(super) fn decode_column_name<T: ColumnNameDecode>(name: T, index: u16) -> String {
125 name.decode_or_default(index)
126}
127
128pub struct OdbcConnection {
133 pub(crate) conn: SharedConnection<'static>,
134 pub(crate) stmt_cache: StatementCache<SharedPreparedStatement>,
135 pub(crate) buffer_settings: OdbcBufferSettings,
136 pub(crate) log_settings: LogSettings,
137}
138
139impl std::fmt::Debug for OdbcConnection {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("OdbcConnection")
142 .field("conn", &self.conn)
143 .field("buffer_settings", &self.buffer_settings)
144 .finish()
145 }
146}
147
148impl OdbcConnection {
149 pub(crate) async fn with_conn<R, F, S>(&mut self, operation: S, f: F) -> Result<R, Error>
150 where
151 R: Send + 'static,
152 F: FnOnce(&mut odbc_api::Connection<'static>) -> Result<R, Error> + Send + 'static,
153 S: std::fmt::Display + Send + 'static,
154 {
155 let conn = Arc::clone(&self.conn);
156 spawn_blocking(move || {
157 let mut conn_guard = conn.lock().map_err(|_| {
158 Error::Protocol(format!("ODBC {}: failed to lock connection", operation))
159 })?;
160 f(&mut conn_guard)
161 })
162 .await
163 }
164
165 pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result<Self, Error> {
166 let shared_conn = spawn_blocking({
167 let options = options.clone();
168 move || {
169 let conn = establish_connection(&options)?;
170 let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn));
171 Ok::<_, Error>(shared_conn)
172 }
173 })
174 .await?;
175
176 Ok(Self {
177 conn: shared_conn,
178 stmt_cache: StatementCache::new(options.statement_cache_capacity),
179 buffer_settings: options.buffer_settings,
180 log_settings: options.log_settings.clone(),
181 })
182 }
183
184 pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> {
185 self.with_conn("ping", move |conn| {
186 conn.execute("SELECT 1", (), None)?;
187 Ok(())
188 })
189 .await
190 }
191
192 pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> {
193 self.with_conn("begin", move |conn| {
194 conn.set_autocommit(false)?;
195 Ok(())
196 })
197 .await
198 }
199
200 pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> {
201 self.with_conn("commit", move |conn| {
202 conn.commit()?;
203 conn.set_autocommit(true)?;
204 Ok(())
205 })
206 .await
207 }
208
209 pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> {
210 self.with_conn("rollback", move |conn| {
211 conn.rollback()?;
212 conn.set_autocommit(true)?;
213 Ok(())
214 })
215 .await
216 }
217
218 pub(crate) fn execute_stream(
220 &mut self,
221 sql: &str,
222 args: Option<OdbcArguments>,
223 ) -> flume::Receiver<Result<Either<OdbcQueryResult, OdbcRow>, Error>> {
224 let (tx, rx) = flume::bounded(64);
225
226 let sql_owned = sql.to_string();
227 let maybe_prepared = if let Some(prepared) = self.stmt_cache.get_mut(sql) {
228 MaybePrepared::Prepared(Arc::clone(prepared))
229 } else {
230 MaybePrepared::NotPrepared(sql_owned.clone())
231 };
232
233 let conn = Arc::clone(&self.conn);
234 let buffer_settings = self.buffer_settings;
235 let log_settings = self.log_settings.clone();
236 sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || {
237 let mut logger = crate::logger::QueryLogger::new(&sql_owned, log_settings);
238 let result = conn
239 .lock()
240 .map_err(|_| Error::Protocol("ODBC execute: failed to lock connection".into()))
241 .and_then(|mut conn| {
242 execute_sql(
243 &mut conn,
244 maybe_prepared,
245 args,
246 &tx,
247 buffer_settings,
248 &mut logger,
249 )
250 });
251
252 if let Err(e) = result {
253 let _ = tx.send(Err(e));
254 }
255 }));
256
257 rx
258 }
259
260 pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> {
261 while self.stmt_cache.remove_lru().is_some() {}
262 Ok(())
263 }
264
265 async fn prepare_with_metadata_policy<'a>(
266 &mut self,
267 sql: &'a str,
268 store_to_cache: bool,
269 allow_deferred_result_columns: bool,
270 ) -> Result<OdbcStatement<'a>, Error> {
271 let sql_owned = sql.to_string();
272 let cached = self
273 .stmt_cache
274 .get_mut(sql)
275 .map(|prepared| Arc::clone(prepared));
276
277 if let Some(prepared) = cached {
278 let metadata = spawn_blocking(move || {
279 let mut prepared = prepared.lock().map_err(|_| {
280 Error::Protocol("ODBC prepare: failed to lock prepared statement".into())
281 })?;
282 collect_statement_metadata(&mut prepared, allow_deferred_result_columns)
283 .map(|(metadata, _)| metadata)
284 })
285 .await?;
286
287 return Ok(OdbcStatement {
288 sql: Cow::Borrowed(sql),
289 metadata,
290 });
291 }
292
293 let conn = Arc::clone(&self.conn);
294 let sql_clone = sql_owned.clone();
295 let (prepared, metadata, metadata_complete) = spawn_blocking(move || {
296 let mut prepared = conn.into_prepared(&sql_clone)?;
297 let metadata =
298 collect_statement_metadata(&mut prepared, allow_deferred_result_columns)?;
299 Ok::<_, Error>((prepared, metadata.0, metadata.1))
300 })
301 .await?;
302
303 if !allow_deferred_result_columns && !metadata_complete {
304 return Err(Error::Protocol(
305 "ODBC driver did not provide result-column metadata before execution".into(),
306 ));
307 }
308
309 if store_to_cache && metadata_complete && self.stmt_cache.is_enabled() {
310 self.stmt_cache
311 .insert(&sql_owned, Arc::new(Mutex::new(prepared)));
312 }
313
314 Ok(OdbcStatement {
315 sql: Cow::Borrowed(sql),
316 metadata,
317 })
318 }
319
320 pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result<OdbcStatement<'a>, Error> {
321 self.prepare_with_metadata_policy(sql, true, true).await
322 }
323
324 pub(crate) async fn describe_statement<'a>(
325 &mut self,
326 sql: &'a str,
327 ) -> Result<OdbcStatement<'a>, Error> {
328 self.prepare_with_metadata_policy(sql, false, false).await
329 }
330}
331
332pub(crate) enum MaybePrepared {
333 Prepared(SharedPreparedStatement),
334 NotPrepared(String),
335}
336
337impl std::fmt::Debug for MaybePrepared {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 match self {
340 MaybePrepared::Prepared(_) => f.debug_tuple("Prepared").finish(),
341 MaybePrepared::NotPrepared(sql) => f.debug_tuple("NotPrepared").field(sql).finish(),
342 }
343 }
344}
345
346impl Connection for OdbcConnection {
347 type Database = Odbc;
348
349 type Options = OdbcConnectOptions;
350
351 fn close(self) -> BoxFuture<'static, Result<(), Error>> {
352 Box::pin(async move {
353 drop(self);
355 Ok(())
356 })
357 }
358
359 fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
360 Box::pin(async move { Ok(()) })
361 }
362
363 fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
364 Box::pin(self.ping_blocking())
365 }
366
367 fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
368 where
369 Self: Sized,
370 {
371 Transaction::begin(self)
372 }
373
374 fn cached_statements_size(&self) -> usize {
375 self.stmt_cache.len()
376 }
377
378 #[doc(hidden)]
379 fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
380 Box::pin(future::ok(()))
381 }
382
383 #[doc(hidden)]
384 fn should_flush(&self) -> bool {
385 false
386 }
387
388 fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
389 Box::pin(self.clear_cached_statements())
390 }
391
392 fn dbms_name(&mut self) -> BoxFuture<'_, Result<String, Error>> {
393 Box::pin(async move {
394 self.with_conn("dbms_name", move |conn| {
395 Ok(conn.database_management_system_name()?)
396 })
397 .await
398 })
399 }
400}