Skip to main content

sqlx_core_oldapi/odbc/connection/
mod.rs

1use crate::common::StatementCache;
2use crate::connection::{Connection, LogSettings};
3use crate::error::Error;
4use crate::odbc::{
5    Odbc, OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcQueryResult,
6    OdbcRow, OdbcTypeInfo,
7};
8use crate::transaction::Transaction;
9use either::Either;
10use sqlx_rt::spawn_blocking;
11mod odbc_bridge;
12use crate::odbc::{OdbcStatement, OdbcStatementMetadata};
13use futures_core::future::BoxFuture;
14use futures_util::future;
15use odbc_api::ConnectionTransitions;
16use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection};
17use odbc_bridge::{establish_connection, execute_sql};
18use std::borrow::Cow;
19use std::sync::{Arc, Mutex};
20
21mod executor;
22
/// A prepared statement created by `into_prepared` on a clone of the shared connection,
/// so it carries its own handle to that connection instead of borrowing `OdbcConnection`.
type PreparedStatement = Prepared<StatementConnection<SharedConnection<'static>>>;
/// A prepared statement that can be held both by the statement cache and by the blocking
/// tasks that use it; callers lock the mutex before touching the statement.
type SharedPreparedStatement = Arc<Mutex<PreparedStatement>>;
25
26fn collect_columns(
27    prepared: &mut PreparedStatement,
28    parameter_count: usize,
29    allow_deferred_result_columns: bool,
30) -> Result<Vec<OdbcColumn>, Error> {
31    let count = match prepared.num_result_cols() {
32        Ok(count) => count,
33        Err(error) if allow_deferred_result_columns && parameter_count > 0 => {
34            // Some ODBC drivers only expose result columns for parameterized
35            // statements after values are bound and the statement is executed.
36            // Row execution still gets authoritative metadata from the cursor;
37            // describe() uses the strict path and will surface this error.
38            log::debug!("ODBC prepare did not expose result columns: {error}");
39            return Ok(Vec::new());
40        }
41        Err(error) => return Err(error.into()),
42    };
43
44    let mut columns = Vec::with_capacity(count as usize);
45    for i in 1..=count {
46        columns.push(describe_column(prepared, i as u16)?);
47    }
48    Ok(columns)
49}
50
51fn collect_statement_metadata(
52    prepared: &mut PreparedStatement,
53    allow_deferred_result_columns: bool,
54) -> Result<OdbcStatementMetadata, Error> {
55    let parameters = usize::from(prepared.num_params()?);
56
57    Ok(OdbcStatementMetadata {
58        columns: collect_columns(prepared, parameters, allow_deferred_result_columns)?,
59        parameters,
60    })
61}
62
63pub(super) fn describe_column<S>(stmt: &mut S, index: u16) -> Result<OdbcColumn, Error>
64where
65    S: ResultSetMetadata,
66{
67    let mut cd = odbc_api::ColumnDescription::default();
68    stmt.describe_col(index, &mut cd)?;
69
70    Ok(OdbcColumn {
71        name: decode_column_name(cd.name, index),
72        type_info: OdbcTypeInfo::new(cd.data_type),
73        ordinal: usize::from(
74            index
75                .checked_sub(1)
76                .ok_or_else(|| Error::Protocol("ODBC column indices are 1-based".into()))?,
77        ),
78    })
79}
80
81pub(super) trait ColumnNameDecode {
82    fn decode_or_default(self, index: u16) -> String;
83}
84
85impl ColumnNameDecode for Vec<u8> {
86    fn decode_or_default(self, index: u16) -> String {
87        String::from_utf8(self).unwrap_or_else(|_| format!("col{}", index - 1))
88    }
89}
90
91impl ColumnNameDecode for Vec<u16> {
92    fn decode_or_default(self, index: u16) -> String {
93        String::from_utf16(&self).unwrap_or_else(|_| format!("col{}", index - 1))
94    }
95}
96
97pub(super) fn decode_column_name<T: ColumnNameDecode>(name: T, index: u16) -> String {
98    name.decode_or_default(index)
99}
100
/// A connection to an ODBC-accessible database.
///
/// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking
/// thread-pool via `spawn_blocking` and synchronize access with a mutex.
pub struct OdbcConnection {
    /// Shared handle to the underlying ODBC connection. It is cloned into blocking
    /// tasks and locked for the duration of each call.
    pub(crate) conn: SharedConnection<'static>,
    /// Prepared statements keyed by their SQL text.
    pub(crate) stmt_cache: StatementCache<SharedPreparedStatement>,
    /// Buffer configuration passed to `execute_sql` for each execution.
    pub(crate) buffer_settings: OdbcBufferSettings,
    /// Logging configuration used to build a `QueryLogger` for each execution.
    pub(crate) log_settings: LogSettings,
}
111
112impl std::fmt::Debug for OdbcConnection {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("OdbcConnection")
115            .field("conn", &self.conn)
116            .field("buffer_settings", &self.buffer_settings)
117            .finish()
118    }
119}
120
121impl OdbcConnection {
122    pub(crate) async fn with_conn<R, F, S>(&mut self, operation: S, f: F) -> Result<R, Error>
123    where
124        R: Send + 'static,
125        F: FnOnce(&mut odbc_api::Connection<'static>) -> Result<R, Error> + Send + 'static,
126        S: std::fmt::Display + Send + 'static,
127    {
128        let conn = Arc::clone(&self.conn);
129        spawn_blocking(move || {
130            let mut conn_guard = conn.lock().map_err(|_| {
131                Error::Protocol(format!("ODBC {}: failed to lock connection", operation))
132            })?;
133            f(&mut conn_guard)
134        })
135        .await
136    }
137
138    pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result<Self, Error> {
139        let shared_conn = spawn_blocking({
140            let options = options.clone();
141            move || {
142                let conn = establish_connection(&options)?;
143                let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn));
144                Ok::<_, Error>(shared_conn)
145            }
146        })
147        .await?;
148
149        Ok(Self {
150            conn: shared_conn,
151            stmt_cache: StatementCache::new(options.statement_cache_capacity),
152            buffer_settings: options.buffer_settings,
153            log_settings: options.log_settings.clone(),
154        })
155    }
156
157    pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> {
158        self.with_conn("ping", move |conn| {
159            conn.execute("SELECT 1", (), None)?;
160            Ok(())
161        })
162        .await
163    }
164
165    pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> {
166        self.with_conn("begin", move |conn| {
167            conn.set_autocommit(false)?;
168            Ok(())
169        })
170        .await
171    }
172
173    pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> {
174        self.with_conn("commit", move |conn| {
175            conn.commit()?;
176            conn.set_autocommit(true)?;
177            Ok(())
178        })
179        .await
180    }
181
182    pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> {
183        self.with_conn("rollback", move |conn| {
184            conn.rollback()?;
185            conn.set_autocommit(true)?;
186            Ok(())
187        })
188        .await
189    }
190
191    /// Launches a background task to execute the SQL statement and send the results to the returned channel.
192    pub(crate) fn execute_stream(
193        &mut self,
194        sql: &str,
195        args: Option<OdbcArguments>,
196    ) -> flume::Receiver<Result<Either<OdbcQueryResult, OdbcRow>, Error>> {
197        let (tx, rx) = flume::bounded(64);
198
199        let sql_owned = sql.to_string();
200        let maybe_prepared = if let Some(prepared) = self.stmt_cache.get_mut(sql) {
201            MaybePrepared::Prepared(Arc::clone(prepared))
202        } else {
203            MaybePrepared::NotPrepared(sql_owned.clone())
204        };
205
206        let conn = Arc::clone(&self.conn);
207        let buffer_settings = self.buffer_settings;
208        let log_settings = self.log_settings.clone();
209        sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || {
210            let mut logger = crate::logger::QueryLogger::new(&sql_owned, log_settings);
211            let result = conn
212                .lock()
213                .map_err(|_| Error::Protocol("ODBC execute: failed to lock connection".into()))
214                .and_then(|mut conn| {
215                    execute_sql(
216                        &mut conn,
217                        maybe_prepared,
218                        args,
219                        &tx,
220                        buffer_settings,
221                        &mut logger,
222                    )
223                });
224
225            if let Err(e) = result {
226                let _ = tx.send(Err(e));
227            }
228        }));
229
230        rx
231    }
232
233    pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> {
234        while self.stmt_cache.remove_lru().is_some() {}
235        Ok(())
236    }
237
238    async fn prepare_with_metadata_policy<'a>(
239        &mut self,
240        sql: &'a str,
241        store_to_cache: bool,
242        allow_deferred_result_columns: bool,
243    ) -> Result<OdbcStatement<'a>, Error> {
244        let cached = self
245            .stmt_cache
246            .get_mut(sql)
247            .map(|prepared| Arc::clone(prepared));
248
249        if let Some(prepared) = cached {
250            let metadata = spawn_blocking(move || {
251                let mut prepared = prepared.lock().map_err(|_| {
252                    Error::Protocol("ODBC prepare: failed to lock prepared statement".into())
253                })?;
254                collect_statement_metadata(&mut prepared, allow_deferred_result_columns)
255            })
256            .await?;
257
258            return Ok(OdbcStatement {
259                sql: Cow::Borrowed(sql),
260                metadata,
261            });
262        }
263
264        let conn = Arc::clone(&self.conn);
265        let sql_owned = sql.to_string();
266        let sql_clone = sql_owned.clone();
267        let (prepared, metadata) = spawn_blocking(move || {
268            let mut prepared = conn.into_prepared(&sql_clone)?;
269            let metadata =
270                collect_statement_metadata(&mut prepared, allow_deferred_result_columns)?;
271            Ok::<_, Error>((prepared, metadata))
272        })
273        .await?;
274
275        if store_to_cache && self.stmt_cache.is_enabled() {
276            self.stmt_cache
277                .insert(&sql_owned, Arc::new(Mutex::new(prepared)));
278        }
279
280        Ok(OdbcStatement {
281            sql: Cow::Borrowed(sql),
282            metadata,
283        })
284    }
285
286    pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result<OdbcStatement<'a>, Error> {
287        self.prepare_with_metadata_policy(sql, true, true).await
288    }
289
290    pub(crate) async fn describe_statement<'a>(
291        &mut self,
292        sql: &'a str,
293    ) -> Result<OdbcStatement<'a>, Error> {
294        self.prepare_with_metadata_policy(sql, false, false).await
295    }
296}
297
/// What `execute_stream` hands to `execute_sql`: either a cached prepared statement or
/// the raw SQL text when no prepared statement is cached for it.
pub(crate) enum MaybePrepared {
    /// A prepared statement taken from the connection's statement cache.
    Prepared(SharedPreparedStatement),
    /// SQL text with no cached prepared statement.
    NotPrepared(String),
}
302
303impl std::fmt::Debug for MaybePrepared {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        match self {
306            MaybePrepared::Prepared(_) => f.debug_tuple("Prepared").finish(),
307            MaybePrepared::NotPrepared(sql) => f.debug_tuple("NotPrepared").field(sql).finish(),
308        }
309    }
310}
311
312impl Connection for OdbcConnection {
313    type Database = Odbc;
314
315    type Options = OdbcConnectOptions;
316
317    fn close(self) -> BoxFuture<'static, Result<(), Error>> {
318        Box::pin(async move {
319            // Drop connection by moving Arc and letting it fall out of scope.
320            drop(self);
321            Ok(())
322        })
323    }
324
325    fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
326        Box::pin(async move { Ok(()) })
327    }
328
329    fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
330        Box::pin(self.ping_blocking())
331    }
332
333    fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
334    where
335        Self: Sized,
336    {
337        Transaction::begin(self)
338    }
339
340    fn cached_statements_size(&self) -> usize {
341        self.stmt_cache.len()
342    }
343
344    #[doc(hidden)]
345    fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
346        Box::pin(future::ok(()))
347    }
348
349    #[doc(hidden)]
350    fn should_flush(&self) -> bool {
351        false
352    }
353
354    fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
355        Box::pin(self.clear_cached_statements())
356    }
357
358    fn dbms_name(&mut self) -> BoxFuture<'_, Result<String, Error>> {
359        Box::pin(async move {
360            self.with_conn("dbms_name", move |conn| {
361                Ok(conn.database_management_system_name()?)
362            })
363            .await
364        })
365    }
366}