sqlx_core_oldapi/odbc/connection/mod.rs

1use crate::connection::Connection;
2use crate::error::Error;
3use crate::odbc::{
4    Odbc, OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcQueryResult,
5    OdbcRow, OdbcTypeInfo,
6};
7use crate::transaction::Transaction;
8use either::Either;
9use sqlx_rt::spawn_blocking;
10mod odbc_bridge;
11use crate::odbc::{OdbcStatement, OdbcStatementMetadata};
12use futures_core::future::BoxFuture;
13use futures_util::future;
14use odbc_api::ConnectionTransitions;
15use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection};
16use odbc_bridge::{establish_connection, execute_sql};
17use std::borrow::Cow;
18use std::collections::HashMap;
19use std::sync::{Arc, Mutex};
20
21mod executor;
22
/// A prepared ODBC statement that holds its own `SharedConnection` handle, so it
/// carries no borrow of `OdbcConnection` and can be cached and moved between threads.
type PreparedStatement = Prepared<StatementConnection<SharedConnection<'static>>>;
/// A prepared statement behind `Arc<Mutex<_>>`: stored in `OdbcConnection::stmt_cache`
/// and cloned into blocking worker tasks by `execute_stream`.
type SharedPreparedStatement = Arc<Mutex<PreparedStatement>>;
25
26fn collect_columns(prepared: &mut PreparedStatement) -> Vec<OdbcColumn> {
27    let count = prepared.num_result_cols().unwrap_or(0);
28    (1..=count)
29        .map(|i| create_column(prepared, i as u16))
30        .collect()
31}
32
33fn create_column(stmt: &mut PreparedStatement, index: u16) -> OdbcColumn {
34    let mut cd = odbc_api::ColumnDescription::default();
35    let _ = stmt.describe_col(index, &mut cd);
36
37    OdbcColumn {
38        name: decode_column_name(cd.name, index),
39        type_info: OdbcTypeInfo::new(cd.data_type),
40        ordinal: usize::from(index.checked_sub(1).unwrap()),
41    }
42}
43
/// Decodes a raw column-name buffer reported by the driver into a `String`.
///
/// Implemented in this module for UTF-8 (`Vec<u8>`) and UTF-16 (`Vec<u16>`) buffers.
pub(super) trait ColumnNameDecode {
    /// Converts `self` into a `String`. The implementations in this module return
    /// `col{index - 1}` when the buffer is not valid text. `index` is the column
    /// number the name belongs to (1-based when called from `create_column`).
    fn decode_or_default(self, index: u16) -> String;
}
47
48impl ColumnNameDecode for Vec<u8> {
49    fn decode_or_default(self, index: u16) -> String {
50        String::from_utf8(self).unwrap_or_else(|_| format!("col{}", index - 1))
51    }
52}
53
54impl ColumnNameDecode for Vec<u16> {
55    fn decode_or_default(self, index: u16) -> String {
56        String::from_utf16(&self).unwrap_or_else(|_| format!("col{}", index - 1))
57    }
58}
59
60pub(super) fn decode_column_name<T: ColumnNameDecode>(name: T, index: u16) -> String {
61    name.decode_or_default(index)
62}
63
/// A connection to an ODBC-accessible database.
///
/// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking
/// thread-pool via `spawn_blocking` and synchronize access with a mutex.
pub struct OdbcConnection {
    /// The driver connection behind a shared mutex. Clones of this handle are moved
    /// into blocking worker tasks and are held by every prepared statement created
    /// through `prepare`.
    pub(crate) conn: SharedConnection<'static>,
    /// Prepared statements keyed by their SQL text. Filled by `prepare`, looked up
    /// by `execute_stream`, emptied by `clear_cached_statements`.
    pub(crate) stmt_cache: HashMap<Arc<str>, SharedPreparedStatement>,
    /// Buffer sizing passed to `execute_sql` when streaming results.
    pub(crate) buffer_settings: OdbcBufferSettings,
}
73
74impl std::fmt::Debug for OdbcConnection {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("OdbcConnection")
77            .field("conn", &self.conn)
78            .field("buffer_settings", &self.buffer_settings)
79            .finish()
80    }
81}
82
83impl OdbcConnection {
84    pub(crate) async fn with_conn<R, F, S>(&mut self, operation: S, f: F) -> Result<R, Error>
85    where
86        R: Send + 'static,
87        F: FnOnce(&mut odbc_api::Connection<'static>) -> Result<R, Error> + Send + 'static,
88        S: std::fmt::Display + Send + 'static,
89    {
90        let conn = Arc::clone(&self.conn);
91        spawn_blocking(move || {
92            let mut conn_guard = conn.lock().map_err(|_| {
93                Error::Protocol(format!("ODBC {}: failed to lock connection", operation))
94            })?;
95            f(&mut conn_guard)
96        })
97        .await
98    }
99
100    pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result<Self, Error> {
101        let shared_conn = spawn_blocking({
102            let options = options.clone();
103            move || {
104                let conn = establish_connection(&options)?;
105                let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn));
106                Ok::<_, Error>(shared_conn)
107            }
108        })
109        .await?;
110
111        Ok(Self {
112            conn: shared_conn,
113            stmt_cache: HashMap::new(),
114            buffer_settings: options.buffer_settings,
115        })
116    }
117
118    // (dbms_name moved to the Connection trait implementation)
119
120    pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> {
121        self.with_conn("ping", move |conn| {
122            conn.execute("SELECT 1", (), None)?;
123            Ok(())
124        })
125        .await
126    }
127
128    pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> {
129        self.with_conn("begin", move |conn| {
130            conn.set_autocommit(false)?;
131            Ok(())
132        })
133        .await
134    }
135
136    pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> {
137        self.with_conn("commit", move |conn| {
138            conn.commit()?;
139            conn.set_autocommit(true)?;
140            Ok(())
141        })
142        .await
143    }
144
145    pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> {
146        self.with_conn("rollback", move |conn| {
147            conn.rollback()?;
148            conn.set_autocommit(true)?;
149            Ok(())
150        })
151        .await
152    }
153
154    /// Launches a background task to execute the SQL statement and send the results to the returned channel.
155    pub(crate) fn execute_stream(
156        &mut self,
157        sql: &str,
158        args: Option<OdbcArguments>,
159    ) -> flume::Receiver<Result<Either<OdbcQueryResult, OdbcRow>, Error>> {
160        let (tx, rx) = flume::bounded(64);
161
162        let maybe_prepared = if let Some(prepared) = self.stmt_cache.get(sql) {
163            MaybePrepared::Prepared(Arc::clone(prepared))
164        } else {
165            MaybePrepared::NotPrepared(sql.to_string())
166        };
167
168        let conn = Arc::clone(&self.conn);
169        let buffer_settings = self.buffer_settings;
170        sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || {
171            let mut conn = conn.lock().expect("failed to lock connection");
172            if let Err(e) = execute_sql(&mut conn, maybe_prepared, args, &tx, buffer_settings) {
173                let _ = tx.send(Err(e));
174            }
175        }));
176
177        rx
178    }
179
180    pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> {
181        // Clear the statement metadata cache
182        self.stmt_cache.clear();
183        Ok(())
184    }
185
186    pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result<OdbcStatement<'a>, Error> {
187        let conn = Arc::clone(&self.conn);
188        let sql_arc = Arc::from(sql.to_string());
189        let sql_clone = Arc::clone(&sql_arc);
190        let (prepared, metadata) = spawn_blocking(move || {
191            let mut prepared = conn.into_prepared(&sql_clone)?;
192            let metadata = OdbcStatementMetadata {
193                columns: collect_columns(&mut prepared),
194                parameters: usize::from(prepared.num_params().unwrap_or(0)),
195            };
196            Ok::<_, Error>((prepared, metadata))
197        })
198        .await?;
199        self.stmt_cache
200            .insert(Arc::clone(&sql_arc), Arc::new(Mutex::new(prepared)));
201        Ok(OdbcStatement {
202            sql: Cow::Borrowed(sql),
203            metadata,
204        })
205    }
206}
207
/// The statement `execute_stream` hands to `execute_sql`: either a cached prepared
/// statement or the raw SQL text when nothing is cached for it.
pub(crate) enum MaybePrepared {
    /// A statement found in `OdbcConnection::stmt_cache` (previously prepared via `prepare`).
    Prepared(SharedPreparedStatement),
    /// SQL text with no cached prepared statement.
    NotPrepared(String),
}
212
213impl Connection for OdbcConnection {
214    type Database = Odbc;
215
216    type Options = OdbcConnectOptions;
217
218    fn close(self) -> BoxFuture<'static, Result<(), Error>> {
219        Box::pin(async move {
220            // Drop connection by moving Arc and letting it fall out of scope.
221            drop(self);
222            Ok(())
223        })
224    }
225
226    fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
227        Box::pin(async move { Ok(()) })
228    }
229
230    fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
231        Box::pin(self.ping_blocking())
232    }
233
234    fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
235    where
236        Self: Sized,
237    {
238        Transaction::begin(self)
239    }
240
241    #[doc(hidden)]
242    fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
243        Box::pin(future::ok(()))
244    }
245
246    #[doc(hidden)]
247    fn should_flush(&self) -> bool {
248        false
249    }
250
251    fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
252        Box::pin(self.clear_cached_statements())
253    }
254
255    fn dbms_name(&mut self) -> BoxFuture<'_, Result<String, Error>> {
256        Box::pin(async move {
257            self.with_conn("dbms_name", move |conn| {
258                Ok(conn.database_management_system_name()?)
259            })
260            .await
261        })
262    }
263}
264
265// moved helpers to connection/inner.rs