Skip to main content

sqlx_odbc/
connection.rs

1use crate::{
2    OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcParameterCollection,
3    OdbcQueryResult, OdbcRow, OdbcStatement, OdbcTypeInfo, OdbcValue, OdbcValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::executor::{Execute, Executor};
12use sqlx_core::transaction::Transaction;
13use sqlx_core::Either;
14use std::future::Future;
15use std::sync::Arc;
16
17/// Blocking ODBC connection wrapper.
18///
19/// This is the minimal smoke-test surface. The SQLx async `Connection` and `Executor` traits will
20/// be implemented as the port progresses.
21pub struct OdbcConnection {
22    conn: odbc_api::Connection<'static>,
23    buffer_settings: OdbcBufferSettings,
24    transaction_depth: usize,
25}
26
27impl std::fmt::Debug for OdbcConnection {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("OdbcConnection").finish_non_exhaustive()
30    }
31}
32
33impl OdbcConnection {
34    /// Opens a blocking ODBC connection with the provided options.
35    pub fn connect_blocking(options: &OdbcConnectOptions) -> Result<Self> {
36        let env = odbc_api::environment().map_err(|error| {
37            crate::OdbcError::Configuration(format!(
38                "failed to initialize the process-wide ODBC environment: {error}"
39            ))
40        })?;
41
42        let conn = env
43            .connect_with_connection_string(options.connection_string(), Default::default())
44            .map_err(|error| {
45                crate::error::database_error_with_context(
46                    error,
47                    "failed to open ODBC connection using the supplied connection string",
48                )
49            })?;
50
51        Ok(Self {
52            conn,
53            buffer_settings: options.buffer_settings,
54            transaction_depth: 0,
55        })
56    }
57
58    /// Executes a minimal connectivity query.
59    pub fn ping_blocking(&mut self) -> Result<()> {
60        let query = self
61            .conn
62            .database_management_system_name()
63            .map(|name| ping_query_for_dbms_name(&name))
64            .unwrap_or("SELECT 1");
65        self.conn.execute(query, (), None).map_err(|error| {
66            crate::error::database_error_with_context(
67                error,
68                format!("ODBC ping query failed: `{query}`"),
69            )
70        })?;
71        Ok(())
72    }
73
74    /// Returns the DBMS name reported by the ODBC driver.
75    pub fn dbms_name(&self) -> Result<String> {
76        self.conn
77            .database_management_system_name()
78            .map_err(|error| {
79                crate::error::database_error_with_context(
80                    error,
81                    "failed to read the ODBC DBMS name from SQLGetInfo",
82                )
83            })
84    }
85
86    pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
87        if self.transaction_depth > 0 {
88            return Err(sqlx_core::Error::InvalidSavePointStatement);
89        }
90
91        self.conn.set_autocommit(false).map_err(|error| {
92            crate::error::database_error_with_context(
93                error,
94                "failed to disable ODBC autocommit while beginning a transaction",
95            )
96        })?;
97        self.transaction_depth = 1;
98        Ok(())
99    }
100
101    pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
102        if self.transaction_depth == 0 {
103            return Ok(());
104        }
105
106        self.conn.commit().map_err(|error| {
107            crate::error::database_error_with_context(
108                error,
109                "failed to commit the active ODBC transaction",
110            )
111        })?;
112        self.conn.set_autocommit(true).map_err(|error| {
113            crate::error::database_error_with_context(
114                error,
115                "failed to restore ODBC autocommit after commit",
116            )
117        })?;
118        self.transaction_depth = 0;
119        Ok(())
120    }
121
122    pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
123        if self.transaction_depth == 0 {
124            return Ok(());
125        }
126
127        self.conn.rollback().map_err(|error| {
128            crate::error::database_error_with_context(
129                error,
130                "failed to roll back the active ODBC transaction",
131            )
132        })?;
133        self.conn.set_autocommit(true).map_err(|error| {
134            crate::error::database_error_with_context(
135                error,
136                "failed to restore ODBC autocommit after rollback",
137            )
138        })?;
139        self.transaction_depth = 0;
140        Ok(())
141    }
142
143    pub(crate) fn start_rollback(&mut self) {
144        if self.transaction_depth == 0 {
145            return;
146        }
147
148        if self.conn.rollback().is_ok() {
149            let _ = self.conn.set_autocommit(true);
150            self.transaction_depth = 0;
151        }
152    }
153
154    pub(crate) const fn transaction_depth(&self) -> usize {
155        self.transaction_depth
156    }
157
158    /// Prepares a statement and returns the metadata reported by the ODBC driver.
159    pub fn prepare_blocking(
160        &mut self,
161        sql: sqlx_core::sql_str::SqlStr,
162    ) -> std::result::Result<OdbcStatement, sqlx_core::Error> {
163        let mut prepared = self.conn.prepare(sql.as_str()).map_err(|error| {
164            crate::error::database_error_with_context(
165                error,
166                format!(
167                    "failed to prepare ODBC statement: `{}`",
168                    sql_preview(sql.as_str())
169                ),
170            )
171        })?;
172        let parameters = prepared.num_params().map_err(|error| {
173            crate::error::database_error_with_context(
174                error,
175                format!(
176                    "failed to read ODBC parameter metadata for prepared statement: `{}`",
177                    sql_preview(sql.as_str())
178                ),
179            )
180        })?;
181        let columns = collect_prepared_columns(&mut prepared, parameters)?;
182
183        Ok(OdbcStatement::new(sql, columns, usize::from(parameters)))
184    }
185
186    pub(crate) fn run_blocking_sql(
187        &mut self,
188        sql: &str,
189        arguments: Option<&OdbcArguments>,
190    ) -> std::result::Result<OdbcExecution, sqlx_core::Error> {
191        let mut statement = self.conn.preallocate().map_err(|error| {
192            crate::error::database_error_with_context(
193                error,
194                format!(
195                    "failed to allocate an ODBC statement for query: `{}`",
196                    sql_preview(sql)
197                ),
198            )
199        })?;
200        let parameters = odbc_parameters(arguments);
201
202        if let Some(cursor) = statement
203            .execute(sql, parameters.as_slice())
204            .map_err(|error| {
205                crate::error::database_error_with_context(
206                    error,
207                    format!("failed to execute ODBC query: `{}`", sql_preview(sql)),
208                )
209            })?
210        {
211            return collect_rows(cursor, self.buffer_settings).map(OdbcExecution::Rows);
212        }
213
214        let rows_affected = statement
215            .row_count()
216            .map_err(|error| {
217                crate::error::database_error_with_context(
218                    error,
219                    format!(
220                        "failed to read ODBC row count for query: `{}`",
221                        sql_preview(sql)
222                    ),
223                )
224            })?
225            .unwrap_or(0);
226
227        let rows_affected = rows_affected.try_into().map_err(|_| {
228            sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned())
229        })?;
230
231        Ok(OdbcExecution::Done(OdbcQueryResult::new(rows_affected)))
232    }
233}
234
235impl sqlx_core::connection::Connection for OdbcConnection {
236    type Database = crate::Odbc;
237    type Options = OdbcConnectOptions;
238
239    async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
240        drop(self);
241        Ok(())
242    }
243
244    async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
245        drop(self);
246        Ok(())
247    }
248
249    async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
250        self.ping_blocking().map_err(Into::into)
251    }
252
253    fn begin(
254        &mut self,
255    ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
256           + Send
257           + '_ {
258        Transaction::begin(self, None)
259    }
260
261    fn shrink_buffers(&mut self) {}
262
263    async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
264        Ok(())
265    }
266
267    fn should_flush(&self) -> bool {
268        false
269    }
270}
271
272impl<'c> Executor<'c> for &'c mut OdbcConnection {
273    type Database = crate::Odbc;
274
275    fn fetch_many<'e, 'q, E>(
276        self,
277        mut query: E,
278    ) -> BoxStream<'e, std::result::Result<Either<OdbcQueryResult, OdbcRow>, sqlx_core::Error>>
279    where
280        'c: 'e,
281        E: Execute<'q, Self::Database>,
282        'q: 'e,
283        E: 'q,
284    {
285        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
286        let sql = query.sql();
287
288        stream::once(async move {
289            let arguments = arguments?;
290            self.run_blocking_sql(sql.as_str(), arguments.as_ref())
291        })
292        .map(|result| match result {
293            Ok(OdbcExecution::Done(result)) => {
294                stream::once(future::ready(Ok(Either::Left(result)))).boxed()
295            }
296            Ok(OdbcExecution::Rows(rows)) => stream::iter(
297                rows.into_iter()
298                    .map(|row| Ok(Either::Right(row)))
299                    .chain(std::iter::once(Ok(Either::Left(OdbcQueryResult::new(0))))),
300            )
301            .boxed(),
302            Err(error) => stream::once(future::ready(Err(error))).boxed(),
303        })
304        .flatten()
305        .boxed()
306    }
307
308    fn fetch_optional<'e, 'q, E>(
309        self,
310        mut query: E,
311    ) -> BoxFuture<'e, std::result::Result<Option<OdbcRow>, sqlx_core::Error>>
312    where
313        'c: 'e,
314        E: Execute<'q, Self::Database>,
315        'q: 'e,
316        E: 'q,
317    {
318        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
319        let sql = query.sql();
320
321        Box::pin(async move {
322            let arguments = arguments?;
323
324            match self.run_blocking_sql(sql.as_str(), arguments.as_ref())? {
325                OdbcExecution::Rows(rows) => Ok(rows.into_iter().next()),
326                OdbcExecution::Done(_) => Ok(None),
327            }
328        })
329    }
330
331    fn prepare_with<'e>(
332        self,
333        sql: sqlx_core::sql_str::SqlStr,
334        _parameters: &[crate::OdbcTypeInfo],
335    ) -> BoxFuture<'e, std::result::Result<OdbcStatement, sqlx_core::Error>>
336    where
337        'c: 'e,
338    {
339        Box::pin(async move { self.prepare_blocking(sql) })
340    }
341}
342
343pub(crate) enum OdbcExecution {
344    Done(OdbcQueryResult),
345    Rows(Vec<OdbcRow>),
346}
347
348fn odbc_parameters(arguments: Option<&OdbcArguments>) -> OdbcParameterCollection {
349    arguments
350        .map(OdbcArguments::to_odbc_parameter_collection)
351        .unwrap_or_default()
352}
353
/// Picks a cheap liveness query for the DBMS identified by `dbms_name` (the name reported by
/// the driver).
///
/// DB2-family systems (DB2, DB/2, iSeries, AS/400, IBM i) are pinged with
/// `SELECT 1 FROM SYSIBM.SYSDUMMY1`, DB2's one-row dummy table. Every other DBMS gets
/// `SELECT 1`. Matching is ASCII case-insensitive.
///
/// The `IBM i` marker is matched as a whole word. Otherwise names such as
/// `IBM Informix Dynamic Server` would be misclassified as DB2 and pinged against a table
/// that does not exist there.
fn ping_query_for_dbms_name(dbms_name: &str) -> &'static str {
    /// Returns `true` if `needle` occurs in `haystack` with no alphanumeric character
    /// directly before or after it.
    fn contains_word(haystack: &str, needle: &str) -> bool {
        haystack.match_indices(needle).any(|(start, matched)| {
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + matched.len()..].chars().next();
            !before.is_some_and(char::is_alphanumeric) && !after.is_some_and(char::is_alphanumeric)
        })
    }

    const DB2_MARKERS: [&str; 4] = ["DB2", "DB/2", "ISERIES", "AS/400"];

    let dbms_name = dbms_name.to_ascii_uppercase();
    let is_db2_family = DB2_MARKERS
        .iter()
        .any(|&marker| dbms_name.contains(marker))
        || contains_word(&dbms_name, "IBM I");

    if is_db2_family {
        "SELECT 1 FROM SYSIBM.SYSDUMMY1"
    } else {
        "SELECT 1"
    }
}
368
369fn collect_columns(
370    cursor: &mut impl ResultSetMetadata,
371) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
372    let count = cursor.num_result_cols().map_err(|error| {
373        crate::error::database_error_with_context(error, "failed to read ODBC result-column count")
374    })?;
375    let count = usize::try_from(count).map_err(|_| {
376        sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
377    })?;
378
379    let mut columns = Vec::with_capacity(count);
380    for ordinal in 0..count {
381        let column_number = u16::try_from(ordinal + 1).map_err(|_| {
382            sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
383        })?;
384
385        let mut description = odbc_api::ColumnDescription::default();
386        cursor
387            .describe_col(column_number, &mut description)
388            .map_err(|error| {
389                crate::error::database_error_with_context(
390                    error,
391                    format!("failed to describe ODBC result column {column_number}"),
392                )
393            })?;
394        let name = description
395            .name_to_string()
396            .unwrap_or_else(|_| format!("col{ordinal}"));
397
398        columns.push(OdbcColumn::new(
399            ordinal,
400            name,
401            OdbcTypeInfo::new(description.data_type),
402        ));
403    }
404
405    Ok(columns)
406}
407
408fn collect_prepared_columns(
409    prepared: &mut impl PreparedStatementMetadata,
410    parameter_count: u16,
411) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
412    match collect_columns(prepared) {
413        Ok(columns) => Ok(columns),
414        Err(error) if parameter_count > 0 => {
415            validate_parameter_metadata(prepared, parameter_count)?;
416            log::debug!("ODBC driver deferred result-column metadata until execution: {error}");
417            Ok(Vec::new())
418        }
419        Err(error) => Err(error),
420    }
421}
422
423trait PreparedStatementMetadata: ResultSetMetadata {
424    fn describe_prepared_parameter(
425        &mut self,
426        index: u16,
427    ) -> std::result::Result<(), odbc_api::Error>;
428}
429
430impl<S> PreparedStatementMetadata for odbc_api::Prepared<S>
431where
432    S: odbc_api::handles::AsStatementRef,
433{
434    fn describe_prepared_parameter(
435        &mut self,
436        index: u16,
437    ) -> std::result::Result<(), odbc_api::Error> {
438        self.describe_param(index).map(|_| ())
439    }
440}
441
442fn validate_parameter_metadata(
443    prepared: &mut impl PreparedStatementMetadata,
444    parameter_count: u16,
445) -> std::result::Result<(), sqlx_core::Error> {
446    for index in 1..=parameter_count {
447        prepared
448            .describe_prepared_parameter(index)
449            .map_err(|error| {
450                crate::error::database_error_with_context(
451                    error,
452                    format!("failed to describe ODBC parameter {index}"),
453                )
454            })?;
455    }
456
457    Ok(())
458}
459
460fn collect_rows<C>(
461    cursor: C,
462    settings: OdbcBufferSettings,
463) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
464where
465    C: Cursor + ResultSetMetadata,
466{
467    if let Some(max_column_size) = settings.max_column_size {
468        collect_rows_buffered(cursor, settings.batch_size, max_column_size)
469    } else {
470        collect_rows_unbuffered(cursor)
471    }
472}
473
474#[derive(Debug)]
475struct ColumnBinding {
476    column: OdbcColumn,
477    buffer_desc: BufferDesc,
478}
479
480fn collect_rows_buffered<C>(
481    cursor: C,
482    batch_size: usize,
483    max_column_size: usize,
484) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
485where
486    C: Cursor + ResultSetMetadata,
487{
488    let mut cursor = cursor;
489    let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
490    let buffer_descriptions = bindings
491        .iter()
492        .map(|binding| binding.buffer_desc)
493        .collect::<Vec<_>>();
494    let mut row_set_cursor = cursor
495        .bind_buffer(ColumnarDynBuffer::from_descs(
496            batch_size,
497            buffer_descriptions,
498        ))
499        .map_err(|error| {
500            crate::error::database_error_with_context(
501                error,
502                format!(
503                    "ODBC buffered fetching could not be enabled with batch_size={batch_size}; \
504                     this driver may reject the row-array or row-binding statement attributes \
505                     used for column-wise buffered fetching, so use \
506                     OdbcConnectOptions::max_column_size(None) to fetch rows unbuffered"
507                ),
508            )
509        })?;
510    let columns: Arc<[OdbcColumn]> = bindings
511        .iter()
512        .map(|binding| binding.column.clone())
513        .collect::<Vec<_>>()
514        .into();
515    let mut rows = Vec::new();
516
517    while let Some(batch) = row_set_cursor.fetch().map_err(|error| {
518        crate::error::database_error_with_context(error, "ODBC buffered fetch failed")
519    })? {
520        let column_values = bindings
521            .iter()
522            .enumerate()
523            .map(|(index, binding)| {
524                buffered_column_values(batch.column(index), binding).map_err(|error| {
525                    sqlx_core::Error::Protocol(format!(
526                        "ODBC buffered fetch could not convert column {} (`{}`) using buffer {:?}: {error}",
527                        binding.column.ordinal() + 1,
528                        binding.column.name(),
529                        binding.buffer_desc
530                    ))
531                })
532            })
533            .collect::<std::result::Result<Vec<_>, _>>()?;
534
535        let mut column_iters = column_values
536            .into_iter()
537            .map(Vec::into_iter)
538            .collect::<Vec<_>>();
539
540        for row_index in 0..batch.num_rows() {
541            let values = column_iters
542                .iter_mut()
543                .map(|values| {
544                    values.next().map(OdbcValue::new).ok_or_else(|| {
545                        sqlx_core::Error::Protocol(format!(
546                            "ODBC buffered fetch produced too few values for row {}",
547                            row_index + 1
548                        ))
549                    })
550                })
551                .collect::<std::result::Result<Vec<_>, _>>()?;
552            rows.push(OdbcRow::new_shared(Arc::clone(&columns), values));
553        }
554    }
555
556    Ok(rows)
557}
558
559fn build_buffer_bindings(
560    cursor: &mut impl ResultSetMetadata,
561    max_column_size: usize,
562) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
563    collect_columns(cursor).map(|columns| {
564        columns
565            .into_iter()
566            .map(|column| ColumnBinding {
567                buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size),
568                column,
569            })
570            .collect()
571    })
572}
573
574fn map_buffer_desc(data_type: DataType, max_column_size: usize) -> BufferDesc {
575    match data_type {
576        DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
577            BufferDesc::I64 { nullable: true }
578        }
579        DataType::Real => BufferDesc::F32 { nullable: true },
580        DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable: true },
581        DataType::Bit => BufferDesc::Bit { nullable: true },
582        DataType::Date => BufferDesc::Date { nullable: true },
583        DataType::Time { .. } => BufferDesc::Time { nullable: true },
584        DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable: true },
585        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
586            BufferDesc::Binary {
587                max_bytes: max_column_size,
588            }
589        }
590        DataType::Char { .. }
591        | DataType::WChar { .. }
592        | DataType::Varchar { .. }
593        | DataType::WVarchar { .. }
594        | DataType::LongVarchar { .. }
595        | DataType::WLongVarchar { .. }
596        | DataType::Other { .. }
597        | DataType::Unknown
598        | DataType::Decimal { .. }
599        | DataType::Numeric { .. } => BufferDesc::Text {
600            max_str_len: max_column_size,
601        },
602    }
603}
604
605fn buffered_column_values(
606    slice: AnyColumnBufferSlice<'_>,
607    binding: &ColumnBinding,
608) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error> {
609    let desc = binding.buffer_desc;
610    Ok(match desc {
611        BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
612            OdbcValueKind::TinyInt(value)
613        })?,
614        BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
615            OdbcValueKind::SmallInt(value)
616        })?,
617        BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
618            OdbcValueKind::Integer(value)
619        })?,
620        BufferDesc::I64 { nullable } => {
621            buffered_numeric(&slice, desc, nullable, OdbcValueKind::BigInt)?
622        }
623        BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
624            OdbcValueKind::BigInt(i64::from(value))
625        })?,
626        BufferDesc::F32 { nullable } => {
627            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Real)?
628        }
629        BufferDesc::F64 { nullable } => {
630            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Double)?
631        }
632        BufferDesc::Bit { nullable } => {
633            buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
634                OdbcValueKind::Bit(value.as_bool())
635            })?
636        }
637        BufferDesc::Date { nullable } => {
638            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Date)?
639        }
640        BufferDesc::Time { nullable } => {
641            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Time)?
642        }
643        BufferDesc::Timestamp { nullable } => {
644            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Timestamp)?
645        }
646        BufferDesc::Text { .. } => {
647            let text = expect_buffer_slice(slice.as_text(), desc)?;
648            text.iter()
649                .map(|value| {
650                    value
651                        .map(|bytes| {
652                            OdbcValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
653                        })
654                        .unwrap_or(OdbcValueKind::Null)
655                })
656                .collect()
657        }
658        BufferDesc::WText { .. } => {
659            let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
660            text.iter()
661                .map(|value| {
662                    value
663                        .map(|chars| OdbcValueKind::Text(String::from_utf16_lossy(chars.into())))
664                        .unwrap_or(OdbcValueKind::Null)
665                })
666                .collect()
667        }
668        BufferDesc::Binary { .. } => {
669            let binary = expect_buffer_slice(slice.as_binary(), desc)?;
670            binary
671                .iter()
672                .map(|value| {
673                    value
674                        .map(|bytes| OdbcValueKind::Binary(bytes.to_vec()))
675                        .unwrap_or(OdbcValueKind::Null)
676                })
677                .collect()
678        }
679        BufferDesc::Numeric => {
680            return Err(sqlx_core::Error::Protocol(format!(
681                "unsupported ODBC buffer descriptor: {desc:?}"
682            )))
683        }
684    })
685}
686
687fn buffered_numeric<T, F>(
688    slice: &AnyColumnBufferSlice<'_>,
689    desc: BufferDesc,
690    nullable: bool,
691    map: F,
692) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error>
693where
694    T: Copy + odbc_api::Pod,
695    F: FnMut(T) -> OdbcValueKind,
696{
697    if nullable {
698        Ok(buffered_nullable_numeric(
699            expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
700            map,
701        ))
702    } else {
703        Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
704            .iter()
705            .copied()
706            .map(map)
707            .collect())
708    }
709}
710
711fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<OdbcValueKind>
712where
713    T: Copy,
714    F: FnMut(T) -> OdbcValueKind,
715{
716    slice
717        .map(|value| value.copied().map(&mut map).unwrap_or(OdbcValueKind::Null))
718        .collect()
719}
720
721fn expect_buffer_slice<T>(
722    slice: Option<T>,
723    desc: BufferDesc,
724) -> std::result::Result<T, sqlx_core::Error> {
725    slice.ok_or_else(|| {
726        sqlx_core::Error::Protocol(format!(
727            "ODBC column buffer {desc:?} did not match fetched slice"
728        ))
729    })
730}
731
732fn collect_rows_unbuffered<C>(mut cursor: C) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
733where
734    C: Cursor + ResultSetMetadata,
735{
736    let columns: Arc<[OdbcColumn]> = collect_columns(&mut cursor)?.into();
737    let mut rows = Vec::new();
738
739    while let Some(mut cursor_row) = cursor.next_row().map_err(|error| {
740        crate::error::database_error_with_context(
741            error,
742            "ODBC unbuffered fetch failed while reading the next row",
743        )
744    })? {
745        let mut values = Vec::with_capacity(columns.len());
746
747        for column in columns.iter() {
748            let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
749                .map_err(|_| {
750                    sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
751                })?;
752            values.push(fetch_value(&mut cursor_row, column_number, column)?);
753        }
754
755        rows.push(OdbcRow::new_shared(Arc::clone(&columns), values));
756    }
757
758    Ok(rows)
759}
760
761fn fetch_value(
762    row: &mut odbc_api::CursorRow<'_>,
763    column_number: u16,
764    column: &OdbcColumn,
765) -> std::result::Result<OdbcValue, sqlx_core::Error> {
766    let data_type = column.type_info().data_type();
767
768    let kind = match data_type {
769        DataType::Bit => {
770            let mut value = Nullable::<odbc_api::Bit>::null();
771            row.get_data(column_number, &mut value).map_err(|error| {
772                crate::error::database_error_with_context_lazy(error, || {
773                    fetch_context(column, data_type)
774                })
775            })?;
776            value
777                .into_opt()
778                .map(|value| OdbcValueKind::Bit(value.as_bool()))
779                .unwrap_or(OdbcValueKind::Null)
780        }
781        DataType::TinyInt => fetch_nullable(
782            row,
783            column_number,
784            column,
785            data_type,
786            OdbcValueKind::TinyInt,
787        )?,
788        DataType::SmallInt => fetch_nullable(
789            row,
790            column_number,
791            column,
792            data_type,
793            OdbcValueKind::SmallInt,
794        )?,
795        DataType::Integer => fetch_nullable(
796            row,
797            column_number,
798            column,
799            data_type,
800            OdbcValueKind::Integer,
801        )?,
802        DataType::BigInt => {
803            fetch_nullable(row, column_number, column, data_type, OdbcValueKind::BigInt)?
804        }
805        DataType::Real => {
806            fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Real)?
807        }
808        DataType::Float { .. } | DataType::Double => {
809            fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Double)?
810        }
811        DataType::Date => {
812            fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Date)?
813        }
814        DataType::Time { .. } => {
815            fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Time)?
816        }
817        DataType::Timestamp { .. } => fetch_nullable(
818            row,
819            column_number,
820            column,
821            data_type,
822            OdbcValueKind::Timestamp,
823        )?,
824        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
825            let mut value = Vec::new();
826            if row.get_binary(column_number, &mut value).map_err(|error| {
827                crate::error::database_error_with_context_lazy(error, || {
828                    fetch_context(column, data_type)
829                })
830            })? {
831                OdbcValueKind::Binary(value)
832            } else {
833                OdbcValueKind::Null
834            }
835        }
836        _ => {
837            let mut value = Vec::new();
838            if row
839                .get_wide_text(column_number, &mut value)
840                .map_err(|error| {
841                    crate::error::database_error_with_context_lazy(error, || {
842                        fetch_context(column, data_type)
843                    })
844                })?
845            {
846                OdbcValueKind::Text(String::from_utf16_lossy(&value))
847            } else {
848                OdbcValueKind::Null
849            }
850        }
851    };
852
853    Ok(OdbcValue::new(kind))
854}
855
856fn fetch_nullable<T, F>(
857    row: &mut odbc_api::CursorRow<'_>,
858    column_number: u16,
859    column: &OdbcColumn,
860    data_type: DataType,
861    map: F,
862) -> std::result::Result<OdbcValueKind, sqlx_core::Error>
863where
864    T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
865    Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
866    F: FnOnce(T) -> OdbcValueKind,
867{
868    let mut value = Nullable::<T>::null();
869    row.get_data(column_number, &mut value).map_err(|error| {
870        crate::error::database_error_with_context_lazy(error, || fetch_context(column, data_type))
871    })?;
872    Ok(value.into_opt().map(map).unwrap_or(OdbcValueKind::Null))
873}
874
875fn fetch_context(column: &OdbcColumn, data_type: DataType) -> String {
876    format!(
877        "failed to fetch ODBC column {} (`{}`) as {data_type:?}",
878        column.ordinal() + 1,
879        column.name()
880    )
881}
882
/// Collapses whitespace runs in `sql` to single spaces and caps the result at 160 characters,
/// for use in error messages.
///
/// Longer statements are cut to 157 characters followed by `...`. Length is measured in
/// `char`s, not bytes. Non-ASCII SQL is therefore neither truncated early nor given a `...`
/// when nothing was removed.
fn sql_preview(sql: &str) -> String {
    const MAX_LEN: usize = 160;
    const ELLIPSIS: &str = "...";

    let compact = sql.split_whitespace().collect::<Vec<_>>().join(" ");
    // Compare in chars: the truncation below also counts chars, so a byte-length check would
    // append the ellipsis to multibyte SQL that was never actually shortened.
    if compact.chars().count() <= MAX_LEN {
        return compact;
    }

    let mut preview = compact
        .chars()
        .take(MAX_LEN - ELLIPSIS.len())
        .collect::<String>();
    preview.push_str(ELLIPSIS);
    preview
}
895
#[cfg(test)]
mod tests {
    use super::*;

    /// Ping query expected for DB2-family DBMS names.
    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";
    /// Ping query expected for every other DBMS name.
    const DEFAULT_PING: &str = "SELECT 1";

    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            let desc = map_buffer_desc(data_type, 64);
            assert!(
                matches!(desc, BufferDesc::I64 { nullable: true }),
                "{data_type:?} mapped to {desc:?}"
            );
        }
    }

    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        let text = map_buffer_desc(DataType::Varchar { length: None }, 32);
        assert_eq!(text, BufferDesc::Text { max_str_len: 32 });

        let binary = map_buffer_desc(DataType::Varbinary { length: None }, 16);
        assert_eq!(binary, BufferDesc::Binary { max_bytes: 16 });
    }

    #[test]
    fn ping_query_uses_db2_dummy_table_for_db2_drivers() {
        for name in ["DB2", "DB2 UDB for AS/400", "IBM DB2 for i", "iSeries", "IBM i"] {
            assert_eq!(DB2_PING, ping_query_for_dbms_name(name), "DBMS name {name:?}");
        }
    }

    #[test]
    fn ping_query_keeps_select_one_for_non_db2_drivers() {
        for name in ["DuckDB", "Microsoft SQL Server", "PostgreSQL"] {
            assert_eq!(DEFAULT_PING, ping_query_for_dbms_name(name), "DBMS name {name:?}");
        }
    }
}