Skip to main content

sqlx_core_oldapi/odbc/connection/
mod.rs

1use crate::common::StatementCache;
2use crate::connection::{Connection, LogSettings};
3use crate::error::Error;
4use crate::odbc::{
5    Odbc, OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcQueryResult,
6    OdbcRow, OdbcTypeInfo,
7};
8use crate::transaction::Transaction;
9use either::Either;
10use sqlx_rt::spawn_blocking;
11mod odbc_bridge;
12use crate::odbc::{OdbcStatement, OdbcStatementMetadata};
13use futures_core::future::BoxFuture;
14use futures_util::future;
15use odbc_api::{
16    handles::StatementConnection, ConnectionTransitions, Prepared, ResultSetMetadata,
17    SharedConnection,
18};
19use odbc_bridge::{establish_connection, execute_sql};
20use std::borrow::Cow;
21use std::sync::{Arc, Mutex};
22
23mod executor;
24
25type PreparedStatement = Prepared<StatementConnection<SharedConnection<'static>>>;
26type SharedPreparedStatement = Arc<Mutex<PreparedStatement>>;
27
28struct CollectedColumns {
29    columns: Vec<OdbcColumn>,
30    deferred: bool,
31}
32
33fn collect_columns(
34    prepared: &mut PreparedStatement,
35    parameter_count: usize,
36    allow_deferred_result_columns: bool,
37) -> Result<CollectedColumns, Error> {
38    let count = match prepared.num_result_cols() {
39        Ok(count) => count,
40        Err(error) if allow_deferred_result_columns && parameter_count > 0 => {
41            log::debug!("ODBC prepare deferred result columns until execution: {error}");
42            validate_parameter_metadata(prepared, parameter_count)?;
43            return Ok(CollectedColumns {
44                columns: Vec::new(),
45                deferred: true,
46            });
47        }
48        Err(error) => return Err(error.into()),
49    };
50
51    let mut columns = Vec::with_capacity(count as usize);
52    for i in 1..=count {
53        columns.push(describe_column(prepared, i as u16)?);
54    }
55    Ok(CollectedColumns {
56        columns,
57        deferred: false,
58    })
59}
60
61fn validate_parameter_metadata(
62    prepared: &mut PreparedStatement,
63    parameter_count: usize,
64) -> Result<(), Error> {
65    for index in 1..=parameter_count {
66        let parameter_number = u16::try_from(index)
67            .map_err(|_| Error::Protocol(format!("ODBC parameter index {index} exceeds u16")))?;
68        prepared.describe_param(parameter_number)?;
69    }
70    Ok(())
71}
72
73fn collect_statement_metadata(
74    prepared: &mut PreparedStatement,
75    allow_deferred_result_columns: bool,
76) -> Result<(OdbcStatementMetadata, bool), Error> {
77    let parameters = usize::from(prepared.num_params()?);
78    let collected = collect_columns(prepared, parameters, allow_deferred_result_columns)?;
79    let metadata_complete = !(collected.deferred || parameters > 0 && collected.columns.is_empty());
80
81    Ok((
82        OdbcStatementMetadata {
83            columns: collected.columns,
84            parameters,
85        },
86        metadata_complete,
87    ))
88}
89
90pub(super) fn describe_column<S>(stmt: &mut S, index: u16) -> Result<OdbcColumn, Error>
91where
92    S: ResultSetMetadata,
93{
94    let mut cd = odbc_api::ColumnDescription::default();
95    stmt.describe_col(index, &mut cd)?;
96
97    Ok(OdbcColumn {
98        name: decode_column_name(cd.name, index),
99        type_info: OdbcTypeInfo::new(cd.data_type),
100        ordinal: usize::from(
101            index
102                .checked_sub(1)
103                .ok_or_else(|| Error::Protocol("ODBC column indices are 1-based".into()))?,
104        ),
105    })
106}
107
108pub(super) trait ColumnNameDecode {
109    fn decode_or_default(self, index: u16) -> String;
110}
111
112impl ColumnNameDecode for Vec<u8> {
113    fn decode_or_default(self, index: u16) -> String {
114        String::from_utf8(self).unwrap_or_else(|_| format!("col{}", index - 1))
115    }
116}
117
118impl ColumnNameDecode for Vec<u16> {
119    fn decode_or_default(self, index: u16) -> String {
120        String::from_utf16(&self).unwrap_or_else(|_| format!("col{}", index - 1))
121    }
122}
123
124pub(super) fn decode_column_name<T: ColumnNameDecode>(name: T, index: u16) -> String {
125    name.decode_or_default(index)
126}
127
128/// A connection to an ODBC-accessible database.
129///
130/// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking
131/// thread-pool via `spawn_blocking` and synchronize access with a mutex.
132pub struct OdbcConnection {
133    pub(crate) conn: SharedConnection<'static>,
134    pub(crate) stmt_cache: StatementCache<SharedPreparedStatement>,
135    pub(crate) buffer_settings: OdbcBufferSettings,
136    pub(crate) log_settings: LogSettings,
137}
138
139impl std::fmt::Debug for OdbcConnection {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("OdbcConnection")
142            .field("conn", &self.conn)
143            .field("buffer_settings", &self.buffer_settings)
144            .finish()
145    }
146}
147
148impl OdbcConnection {
149    pub(crate) async fn with_conn<R, F, S>(&mut self, operation: S, f: F) -> Result<R, Error>
150    where
151        R: Send + 'static,
152        F: FnOnce(&mut odbc_api::Connection<'static>) -> Result<R, Error> + Send + 'static,
153        S: std::fmt::Display + Send + 'static,
154    {
155        let conn = Arc::clone(&self.conn);
156        spawn_blocking(move || {
157            let mut conn_guard = conn.lock().map_err(|_| {
158                Error::Protocol(format!("ODBC {}: failed to lock connection", operation))
159            })?;
160            f(&mut conn_guard)
161        })
162        .await
163    }
164
165    pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result<Self, Error> {
166        let shared_conn = spawn_blocking({
167            let options = options.clone();
168            move || {
169                let conn = establish_connection(&options)?;
170                let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn));
171                Ok::<_, Error>(shared_conn)
172            }
173        })
174        .await?;
175
176        Ok(Self {
177            conn: shared_conn,
178            stmt_cache: StatementCache::new(options.statement_cache_capacity),
179            buffer_settings: options.buffer_settings,
180            log_settings: options.log_settings.clone(),
181        })
182    }
183
184    pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> {
185        self.with_conn("ping", move |conn| {
186            conn.execute("SELECT 1", (), None)?;
187            Ok(())
188        })
189        .await
190    }
191
192    pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> {
193        self.with_conn("begin", move |conn| {
194            conn.set_autocommit(false)?;
195            Ok(())
196        })
197        .await
198    }
199
200    pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> {
201        self.with_conn("commit", move |conn| {
202            conn.commit()?;
203            conn.set_autocommit(true)?;
204            Ok(())
205        })
206        .await
207    }
208
209    pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> {
210        self.with_conn("rollback", move |conn| {
211            conn.rollback()?;
212            conn.set_autocommit(true)?;
213            Ok(())
214        })
215        .await
216    }
217
218    /// Launches a background task to execute the SQL statement and send the results to the returned channel.
219    pub(crate) fn execute_stream(
220        &mut self,
221        sql: &str,
222        args: Option<OdbcArguments>,
223    ) -> flume::Receiver<Result<Either<OdbcQueryResult, OdbcRow>, Error>> {
224        let (tx, rx) = flume::bounded(64);
225
226        let sql_owned = sql.to_string();
227        let maybe_prepared = if let Some(prepared) = self.stmt_cache.get_mut(sql) {
228            MaybePrepared::Prepared(Arc::clone(prepared))
229        } else {
230            MaybePrepared::NotPrepared(sql_owned.clone())
231        };
232
233        let conn = Arc::clone(&self.conn);
234        let buffer_settings = self.buffer_settings;
235        let log_settings = self.log_settings.clone();
236        sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || {
237            let mut logger = crate::logger::QueryLogger::new(&sql_owned, log_settings);
238            let result = conn
239                .lock()
240                .map_err(|_| Error::Protocol("ODBC execute: failed to lock connection".into()))
241                .and_then(|mut conn| {
242                    execute_sql(
243                        &mut conn,
244                        maybe_prepared,
245                        args,
246                        &tx,
247                        buffer_settings,
248                        &mut logger,
249                    )
250                });
251
252            if let Err(e) = result {
253                let _ = tx.send(Err(e));
254            }
255        }));
256
257        rx
258    }
259
260    pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> {
261        while self.stmt_cache.remove_lru().is_some() {}
262        Ok(())
263    }
264
265    async fn prepare_with_metadata_policy<'a>(
266        &mut self,
267        sql: &'a str,
268        store_to_cache: bool,
269        allow_deferred_result_columns: bool,
270    ) -> Result<OdbcStatement<'a>, Error> {
271        let sql_owned = sql.to_string();
272        let cached = self
273            .stmt_cache
274            .get_mut(sql)
275            .map(|prepared| Arc::clone(prepared));
276
277        if let Some(prepared) = cached {
278            let metadata = spawn_blocking(move || {
279                let mut prepared = prepared.lock().map_err(|_| {
280                    Error::Protocol("ODBC prepare: failed to lock prepared statement".into())
281                })?;
282                collect_statement_metadata(&mut prepared, allow_deferred_result_columns)
283                    .map(|(metadata, _)| metadata)
284            })
285            .await?;
286
287            return Ok(OdbcStatement {
288                sql: Cow::Borrowed(sql),
289                metadata,
290            });
291        }
292
293        let conn = Arc::clone(&self.conn);
294        let sql_clone = sql_owned.clone();
295        let (prepared, metadata, metadata_complete) = spawn_blocking(move || {
296            let mut prepared = conn.into_prepared(&sql_clone)?;
297            let metadata =
298                collect_statement_metadata(&mut prepared, allow_deferred_result_columns)?;
299            Ok::<_, Error>((prepared, metadata.0, metadata.1))
300        })
301        .await?;
302
303        if !allow_deferred_result_columns && !metadata_complete {
304            return Err(Error::Protocol(
305                "ODBC driver did not provide result-column metadata before execution".into(),
306            ));
307        }
308
309        if store_to_cache && metadata_complete && self.stmt_cache.is_enabled() {
310            self.stmt_cache
311                .insert(&sql_owned, Arc::new(Mutex::new(prepared)));
312        }
313
314        Ok(OdbcStatement {
315            sql: Cow::Borrowed(sql),
316            metadata,
317        })
318    }
319
320    pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result<OdbcStatement<'a>, Error> {
321        self.prepare_with_metadata_policy(sql, true, true).await
322    }
323
324    pub(crate) async fn describe_statement<'a>(
325        &mut self,
326        sql: &'a str,
327    ) -> Result<OdbcStatement<'a>, Error> {
328        self.prepare_with_metadata_policy(sql, false, false).await
329    }
330}
331
332pub(crate) enum MaybePrepared {
333    Prepared(SharedPreparedStatement),
334    NotPrepared(String),
335}
336
337impl std::fmt::Debug for MaybePrepared {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        match self {
340            MaybePrepared::Prepared(_) => f.debug_tuple("Prepared").finish(),
341            MaybePrepared::NotPrepared(sql) => f.debug_tuple("NotPrepared").field(sql).finish(),
342        }
343    }
344}
345
346impl Connection for OdbcConnection {
347    type Database = Odbc;
348
349    type Options = OdbcConnectOptions;
350
351    fn close(self) -> BoxFuture<'static, Result<(), Error>> {
352        Box::pin(async move {
353            // Drop connection by moving Arc and letting it fall out of scope.
354            drop(self);
355            Ok(())
356        })
357    }
358
359    fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
360        Box::pin(async move { Ok(()) })
361    }
362
363    fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
364        Box::pin(self.ping_blocking())
365    }
366
367    fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
368    where
369        Self: Sized,
370    {
371        Transaction::begin(self)
372    }
373
374    fn cached_statements_size(&self) -> usize {
375        self.stmt_cache.len()
376    }
377
378    #[doc(hidden)]
379    fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
380        Box::pin(future::ok(()))
381    }
382
383    #[doc(hidden)]
384    fn should_flush(&self) -> bool {
385        false
386    }
387
388    fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
389        Box::pin(self.clear_cached_statements())
390    }
391
392    fn dbms_name(&mut self) -> BoxFuture<'_, Result<String, Error>> {
393        Box::pin(async move {
394            self.with_conn("dbms_name", move |conn| {
395                Ok(conn.database_management_system_name()?)
396            })
397            .await
398        })
399    }
400}