Skip to main content

sqlx_odbc/
connection.rs

1use crate::{
2    OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcParameterCollection,
3    OdbcQueryResult, OdbcRow, OdbcStatement, OdbcTypeInfo, OdbcValue, OdbcValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::executor::{Execute, Executor};
12use sqlx_core::transaction::Transaction;
13use sqlx_core::Either;
14use std::future::Future;
15
16/// Blocking ODBC connection wrapper.
17///
18/// This is the minimal smoke-test surface. The SQLx async `Connection` and `Executor` traits will
19/// be implemented as the port progresses.
20pub struct OdbcConnection {
21    conn: odbc_api::Connection<'static>,
22    buffer_settings: OdbcBufferSettings,
23    transaction_depth: usize,
24}
25
26impl std::fmt::Debug for OdbcConnection {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("OdbcConnection").finish_non_exhaustive()
29    }
30}
31
32impl OdbcConnection {
33    /// Opens a blocking ODBC connection with the provided options.
34    pub fn connect_blocking(options: &OdbcConnectOptions) -> Result<Self> {
35        let env = odbc_api::environment()
36            .map_err(|error| crate::OdbcError::Configuration(error.to_string()))?;
37
38        let conn =
39            env.connect_with_connection_string(options.connection_string(), Default::default())?;
40
41        Ok(Self {
42            conn,
43            buffer_settings: options.buffer_settings,
44            transaction_depth: 0,
45        })
46    }
47
48    /// Executes a minimal connectivity query.
49    pub fn ping_blocking(&mut self) -> Result<()> {
50        let query = self
51            .conn
52            .database_management_system_name()
53            .map(|name| ping_query_for_dbms_name(&name))
54            .unwrap_or("SELECT 1");
55        self.conn.execute(query, (), None)?;
56        Ok(())
57    }
58
59    /// Returns the DBMS name reported by the ODBC driver.
60    pub fn dbms_name(&self) -> Result<String> {
61        Ok(self.conn.database_management_system_name()?)
62    }
63
64    pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
65        if self.transaction_depth > 0 {
66            return Err(sqlx_core::Error::InvalidSavePointStatement);
67        }
68
69        self.conn
70            .set_autocommit(false)
71            .map_err(crate::OdbcError::from)?;
72        self.transaction_depth = 1;
73        Ok(())
74    }
75
76    pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
77        if self.transaction_depth == 0 {
78            return Ok(());
79        }
80
81        self.conn.commit().map_err(crate::OdbcError::from)?;
82        self.conn
83            .set_autocommit(true)
84            .map_err(crate::OdbcError::from)?;
85        self.transaction_depth = 0;
86        Ok(())
87    }
88
89    pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
90        if self.transaction_depth == 0 {
91            return Ok(());
92        }
93
94        self.conn.rollback().map_err(crate::OdbcError::from)?;
95        self.conn
96            .set_autocommit(true)
97            .map_err(crate::OdbcError::from)?;
98        self.transaction_depth = 0;
99        Ok(())
100    }
101
102    pub(crate) fn start_rollback(&mut self) {
103        if self.transaction_depth == 0 {
104            return;
105        }
106
107        if self.conn.rollback().is_ok() {
108            let _ = self.conn.set_autocommit(true);
109            self.transaction_depth = 0;
110        }
111    }
112
113    pub(crate) const fn transaction_depth(&self) -> usize {
114        self.transaction_depth
115    }
116
117    /// Prepares a statement and returns the metadata reported by the ODBC driver.
118    pub fn prepare_blocking(
119        &mut self,
120        sql: sqlx_core::sql_str::SqlStr,
121    ) -> std::result::Result<OdbcStatement, sqlx_core::Error> {
122        let mut prepared = self
123            .conn
124            .prepare(sql.as_str())
125            .map_err(crate::OdbcError::from)?;
126        let parameters = prepared.num_params().map_err(crate::OdbcError::from)?;
127        let columns = collect_prepared_columns(&mut prepared, parameters)?;
128
129        Ok(OdbcStatement::new(sql, columns, usize::from(parameters)))
130    }
131
132    pub(crate) fn run_blocking_sql(
133        &mut self,
134        sql: &str,
135        arguments: Option<&OdbcArguments>,
136    ) -> std::result::Result<OdbcExecution, sqlx_core::Error> {
137        let mut statement = self.conn.preallocate().map_err(crate::OdbcError::from)?;
138        let parameters = odbc_parameters(arguments);
139
140        if let Some(cursor) = statement
141            .execute(sql, parameters.as_slice())
142            .map_err(crate::OdbcError::from)?
143        {
144            return collect_rows(cursor, self.buffer_settings).map(OdbcExecution::Rows);
145        }
146
147        let rows_affected = statement
148            .row_count()
149            .map_err(crate::OdbcError::from)?
150            .unwrap_or(0);
151
152        let rows_affected = rows_affected.try_into().map_err(|_| {
153            sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned())
154        })?;
155
156        Ok(OdbcExecution::Done(OdbcQueryResult::new(rows_affected)))
157    }
158}
159
160impl sqlx_core::connection::Connection for OdbcConnection {
161    type Database = crate::Odbc;
162    type Options = OdbcConnectOptions;
163
164    async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
165        drop(self);
166        Ok(())
167    }
168
169    async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
170        drop(self);
171        Ok(())
172    }
173
174    async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
175        self.ping_blocking().map_err(Into::into)
176    }
177
178    fn begin(
179        &mut self,
180    ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
181           + Send
182           + '_ {
183        Transaction::begin(self, None)
184    }
185
186    fn shrink_buffers(&mut self) {}
187
188    async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
189        Ok(())
190    }
191
192    fn should_flush(&self) -> bool {
193        false
194    }
195}
196
197impl<'c> Executor<'c> for &'c mut OdbcConnection {
198    type Database = crate::Odbc;
199
200    fn fetch_many<'e, 'q, E>(
201        self,
202        mut query: E,
203    ) -> BoxStream<'e, std::result::Result<Either<OdbcQueryResult, OdbcRow>, sqlx_core::Error>>
204    where
205        'c: 'e,
206        E: Execute<'q, Self::Database>,
207        'q: 'e,
208        E: 'q,
209    {
210        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
211        let sql = query.sql();
212
213        stream::once(async move {
214            let arguments = arguments?;
215            self.run_blocking_sql(sql.as_str(), arguments.as_ref())
216        })
217        .map(|result| match result {
218            Ok(OdbcExecution::Done(result)) => {
219                stream::once(future::ready(Ok(Either::Left(result)))).boxed()
220            }
221            Ok(OdbcExecution::Rows(rows)) => stream::iter(
222                rows.into_iter()
223                    .map(|row| Ok(Either::Right(row)))
224                    .chain(std::iter::once(Ok(Either::Left(OdbcQueryResult::new(0))))),
225            )
226            .boxed(),
227            Err(error) => stream::once(future::ready(Err(error))).boxed(),
228        })
229        .flatten()
230        .boxed()
231    }
232
233    fn fetch_optional<'e, 'q, E>(
234        self,
235        mut query: E,
236    ) -> BoxFuture<'e, std::result::Result<Option<OdbcRow>, sqlx_core::Error>>
237    where
238        'c: 'e,
239        E: Execute<'q, Self::Database>,
240        'q: 'e,
241        E: 'q,
242    {
243        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
244        let sql = query.sql();
245
246        Box::pin(async move {
247            let arguments = arguments?;
248
249            match self.run_blocking_sql(sql.as_str(), arguments.as_ref())? {
250                OdbcExecution::Rows(rows) => Ok(rows.into_iter().next()),
251                OdbcExecution::Done(_) => Ok(None),
252            }
253        })
254    }
255
256    fn prepare_with<'e>(
257        self,
258        sql: sqlx_core::sql_str::SqlStr,
259        _parameters: &[crate::OdbcTypeInfo],
260    ) -> BoxFuture<'e, std::result::Result<OdbcStatement, sqlx_core::Error>>
261    where
262        'c: 'e,
263    {
264        Box::pin(async move { self.prepare_blocking(sql) })
265    }
266}
267
268pub(crate) enum OdbcExecution {
269    Done(OdbcQueryResult),
270    Rows(Vec<OdbcRow>),
271}
272
273fn odbc_parameters(arguments: Option<&OdbcArguments>) -> OdbcParameterCollection {
274    arguments
275        .map(OdbcArguments::to_odbc_parameter_collection)
276        .unwrap_or_default()
277}
278
/// Chooses the statement `OdbcConnection::ping_blocking` uses for a reported DBMS name.
///
/// Db2-family databases (Db2, Db2 for i / iSeries / AS/400) are pinged against the
/// `SYSIBM.SYSDUMMY1` dummy table, since Db2 SQL generally requires a `FROM` clause. Every
/// other DBMS gets a plain `SELECT 1`.
///
/// Matching ignores ASCII case. The platform name `IBM i` is matched as two whole words, so
/// unrelated products whose names merely start with "IBM I…" (e.g. "IBM Informix") are not
/// mistaken for Db2.
fn ping_query_for_dbms_name(dbms_name: &str) -> &'static str {
    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";
    const DEFAULT_PING: &str = "SELECT 1";
    const DB2_MARKERS: [&str; 4] = ["DB2", "DB/2", "ISERIES", "AS/400"];

    let upper = dbms_name.to_ascii_uppercase();

    let has_db2_marker = DB2_MARKERS.iter().any(|marker| upper.contains(*marker));

    // Split on anything that is not a letter or digit so "IBM i", "IBM i 7.4" and "IBM i," all
    // produce the adjacent words ["IBM", "I"], while "IBM INFORMIX" does not.
    let words: Vec<&str> = upper
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|word| !word.is_empty())
        .collect();
    let is_ibm_i = words.windows(2).any(|pair| matches!(pair, ["IBM", "I"]));

    if has_db2_marker || is_ibm_i {
        DB2_PING
    } else {
        DEFAULT_PING
    }
}
293
294fn collect_columns(
295    cursor: &mut impl ResultSetMetadata,
296) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
297    let count = cursor.num_result_cols().map_err(crate::OdbcError::from)?;
298    let count = usize::try_from(count).map_err(|_| {
299        sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
300    })?;
301
302    let mut columns = Vec::with_capacity(count);
303    for ordinal in 0..count {
304        let column_number = u16::try_from(ordinal + 1).map_err(|_| {
305            sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
306        })?;
307
308        let mut description = odbc_api::ColumnDescription::default();
309        cursor
310            .describe_col(column_number, &mut description)
311            .map_err(crate::OdbcError::from)?;
312        let name = description
313            .name_to_string()
314            .unwrap_or_else(|_| format!("col{ordinal}"));
315
316        columns.push(OdbcColumn::new(
317            ordinal,
318            name,
319            OdbcTypeInfo::new(description.data_type),
320        ));
321    }
322
323    Ok(columns)
324}
325
326fn collect_prepared_columns(
327    prepared: &mut impl PreparedStatementMetadata,
328    parameter_count: u16,
329) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
330    match collect_columns(prepared) {
331        Ok(columns) => Ok(columns),
332        Err(error) if parameter_count > 0 => {
333            validate_parameter_metadata(prepared, parameter_count)?;
334            log::debug!("ODBC driver deferred result-column metadata until execution: {error}");
335            Ok(Vec::new())
336        }
337        Err(error) => Err(error),
338    }
339}
340
341trait PreparedStatementMetadata: ResultSetMetadata {
342    fn describe_prepared_parameter(
343        &mut self,
344        index: u16,
345    ) -> std::result::Result<(), odbc_api::Error>;
346}
347
348impl<S> PreparedStatementMetadata for odbc_api::Prepared<S>
349where
350    S: odbc_api::handles::AsStatementRef,
351{
352    fn describe_prepared_parameter(
353        &mut self,
354        index: u16,
355    ) -> std::result::Result<(), odbc_api::Error> {
356        self.describe_param(index).map(|_| ())
357    }
358}
359
360fn validate_parameter_metadata(
361    prepared: &mut impl PreparedStatementMetadata,
362    parameter_count: u16,
363) -> std::result::Result<(), sqlx_core::Error> {
364    for index in 1..=parameter_count {
365        prepared
366            .describe_prepared_parameter(index)
367            .map_err(crate::OdbcError::from)?;
368    }
369
370    Ok(())
371}
372
373fn collect_rows<C>(
374    cursor: C,
375    settings: OdbcBufferSettings,
376) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
377where
378    C: Cursor + ResultSetMetadata,
379{
380    if let Some(max_column_size) = settings.max_column_size {
381        collect_rows_buffered(cursor, settings.batch_size, max_column_size)
382    } else {
383        collect_rows_unbuffered(cursor)
384    }
385}
386
387#[derive(Debug)]
388struct ColumnBinding {
389    column: OdbcColumn,
390    buffer_desc: BufferDesc,
391}
392
393fn collect_rows_buffered<C>(
394    cursor: C,
395    batch_size: usize,
396    max_column_size: usize,
397) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
398where
399    C: Cursor + ResultSetMetadata,
400{
401    let mut cursor = cursor;
402    let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
403    let buffer_descriptions = bindings
404        .iter()
405        .map(|binding| binding.buffer_desc)
406        .collect::<Vec<_>>();
407    let mut row_set_cursor = cursor
408        .bind_buffer(ColumnarDynBuffer::from_descs(
409            batch_size,
410            buffer_descriptions,
411        ))
412        .map_err(|error| {
413            crate::error::database_error_with_context(
414                error,
415                format!(
416                    "ODBC buffered fetching could not be enabled with batch_size={batch_size}; \
417                     this driver may reject the row-array or row-binding statement attributes \
418                     used for column-wise buffered fetching, so use \
419                     OdbcConnectOptions::max_column_size(None) to fetch rows unbuffered"
420                ),
421            )
422        })?;
423    let columns = bindings
424        .iter()
425        .map(|binding| binding.column.clone())
426        .collect::<Vec<_>>();
427    let mut rows = Vec::new();
428
429    while let Some(batch) = row_set_cursor.fetch().map_err(crate::OdbcError::from)? {
430        let column_values = bindings
431            .iter()
432            .enumerate()
433            .map(|(index, binding)| {
434                buffered_column_values(batch.column(index), binding.buffer_desc)
435            })
436            .collect::<std::result::Result<Vec<_>, _>>()?;
437
438        for row_index in 0..batch.num_rows() {
439            let values = column_values
440                .iter()
441                .map(|values| OdbcValue::new(values[row_index].clone()))
442                .collect::<Vec<_>>();
443            rows.push(OdbcRow::new(columns.clone(), values));
444        }
445    }
446
447    Ok(rows)
448}
449
450fn build_buffer_bindings(
451    cursor: &mut impl ResultSetMetadata,
452    max_column_size: usize,
453) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
454    collect_columns(cursor).map(|columns| {
455        columns
456            .into_iter()
457            .map(|column| ColumnBinding {
458                buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size),
459                column,
460            })
461            .collect()
462    })
463}
464
465fn map_buffer_desc(data_type: DataType, max_column_size: usize) -> BufferDesc {
466    match data_type {
467        DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
468            BufferDesc::I64 { nullable: true }
469        }
470        DataType::Real => BufferDesc::F32 { nullable: true },
471        DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable: true },
472        DataType::Bit => BufferDesc::Bit { nullable: true },
473        DataType::Date => BufferDesc::Date { nullable: true },
474        DataType::Time { .. } => BufferDesc::Time { nullable: true },
475        DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable: true },
476        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
477            BufferDesc::Binary {
478                max_bytes: max_column_size,
479            }
480        }
481        DataType::Char { .. }
482        | DataType::WChar { .. }
483        | DataType::Varchar { .. }
484        | DataType::WVarchar { .. }
485        | DataType::LongVarchar { .. }
486        | DataType::WLongVarchar { .. }
487        | DataType::Other { .. }
488        | DataType::Unknown
489        | DataType::Decimal { .. }
490        | DataType::Numeric { .. } => BufferDesc::Text {
491            max_str_len: max_column_size,
492        },
493    }
494}
495
496fn buffered_column_values(
497    slice: AnyColumnBufferSlice<'_>,
498    desc: BufferDesc,
499) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error> {
500    Ok(match desc {
501        BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
502            OdbcValueKind::TinyInt(value)
503        })?,
504        BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
505            OdbcValueKind::SmallInt(value)
506        })?,
507        BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
508            OdbcValueKind::Integer(value)
509        })?,
510        BufferDesc::I64 { nullable } => {
511            buffered_numeric(&slice, desc, nullable, OdbcValueKind::BigInt)?
512        }
513        BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
514            OdbcValueKind::BigInt(i64::from(value))
515        })?,
516        BufferDesc::F32 { nullable } => {
517            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Real)?
518        }
519        BufferDesc::F64 { nullable } => {
520            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Double)?
521        }
522        BufferDesc::Bit { nullable } => {
523            buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
524                OdbcValueKind::Bit(value.as_bool())
525            })?
526        }
527        BufferDesc::Date { nullable } => {
528            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Date)?
529        }
530        BufferDesc::Time { nullable } => {
531            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Time)?
532        }
533        BufferDesc::Timestamp { nullable } => {
534            buffered_numeric(&slice, desc, nullable, OdbcValueKind::Timestamp)?
535        }
536        BufferDesc::Text { .. } => {
537            let text = expect_buffer_slice(slice.as_text(), desc)?;
538            text.iter()
539                .map(|value| {
540                    value
541                        .map(|bytes| {
542                            OdbcValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
543                        })
544                        .unwrap_or(OdbcValueKind::Null)
545                })
546                .collect()
547        }
548        BufferDesc::WText { .. } => {
549            let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
550            text.iter()
551                .map(|value| {
552                    value
553                        .map(|chars| OdbcValueKind::Text(String::from_utf16_lossy(chars.into())))
554                        .unwrap_or(OdbcValueKind::Null)
555                })
556                .collect()
557        }
558        BufferDesc::Binary { .. } => {
559            let binary = expect_buffer_slice(slice.as_binary(), desc)?;
560            binary
561                .iter()
562                .map(|value| {
563                    value
564                        .map(|bytes| OdbcValueKind::Binary(bytes.to_vec()))
565                        .unwrap_or(OdbcValueKind::Null)
566                })
567                .collect()
568        }
569        BufferDesc::Numeric => {
570            return Err(sqlx_core::Error::Protocol(format!(
571                "unsupported ODBC buffer descriptor: {desc:?}"
572            )))
573        }
574    })
575}
576
577fn buffered_numeric<T, F>(
578    slice: &AnyColumnBufferSlice<'_>,
579    desc: BufferDesc,
580    nullable: bool,
581    map: F,
582) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error>
583where
584    T: Copy + odbc_api::Pod,
585    F: FnMut(T) -> OdbcValueKind,
586{
587    if nullable {
588        Ok(buffered_nullable_numeric(
589            expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
590            map,
591        ))
592    } else {
593        Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
594            .iter()
595            .copied()
596            .map(map)
597            .collect())
598    }
599}
600
601fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<OdbcValueKind>
602where
603    T: Copy,
604    F: FnMut(T) -> OdbcValueKind,
605{
606    slice
607        .map(|value| value.copied().map(&mut map).unwrap_or(OdbcValueKind::Null))
608        .collect()
609}
610
611fn expect_buffer_slice<T>(
612    slice: Option<T>,
613    desc: BufferDesc,
614) -> std::result::Result<T, sqlx_core::Error> {
615    slice.ok_or_else(|| {
616        sqlx_core::Error::Protocol(format!(
617            "ODBC column buffer {desc:?} did not match fetched slice"
618        ))
619    })
620}
621
622fn collect_rows_unbuffered<C>(mut cursor: C) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
623where
624    C: Cursor + ResultSetMetadata,
625{
626    let columns = collect_columns(&mut cursor)?;
627    let mut rows = Vec::new();
628
629    while let Some(mut cursor_row) = cursor.next_row().map_err(crate::OdbcError::from)? {
630        let mut values = Vec::with_capacity(columns.len());
631
632        for column in &columns {
633            let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
634                .map_err(|_| {
635                    sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
636                })?;
637            values.push(fetch_value(
638                &mut cursor_row,
639                column_number,
640                column.type_info().data_type(),
641            )?);
642        }
643
644        rows.push(OdbcRow::new(columns.clone(), values));
645    }
646
647    Ok(rows)
648}
649
650fn fetch_value(
651    row: &mut odbc_api::CursorRow<'_>,
652    column_number: u16,
653    data_type: DataType,
654) -> std::result::Result<OdbcValue, sqlx_core::Error> {
655    let kind = match data_type {
656        DataType::Bit => {
657            let mut value = Nullable::<odbc_api::Bit>::null();
658            row.get_data(column_number, &mut value)
659                .map_err(crate::OdbcError::from)?;
660            value
661                .into_opt()
662                .map(|value| OdbcValueKind::Bit(value.as_bool()))
663                .unwrap_or(OdbcValueKind::Null)
664        }
665        DataType::TinyInt => fetch_nullable(row, column_number, OdbcValueKind::TinyInt)?,
666        DataType::SmallInt => fetch_nullable(row, column_number, OdbcValueKind::SmallInt)?,
667        DataType::Integer => fetch_nullable(row, column_number, OdbcValueKind::Integer)?,
668        DataType::BigInt => fetch_nullable(row, column_number, OdbcValueKind::BigInt)?,
669        DataType::Real => fetch_nullable(row, column_number, OdbcValueKind::Real)?,
670        DataType::Float { .. } | DataType::Double => {
671            fetch_nullable(row, column_number, OdbcValueKind::Double)?
672        }
673        DataType::Date => fetch_nullable(row, column_number, OdbcValueKind::Date)?,
674        DataType::Time { .. } => fetch_nullable(row, column_number, OdbcValueKind::Time)?,
675        DataType::Timestamp { .. } => fetch_nullable(row, column_number, OdbcValueKind::Timestamp)?,
676        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
677            let mut value = Vec::new();
678            if row
679                .get_binary(column_number, &mut value)
680                .map_err(crate::OdbcError::from)?
681            {
682                OdbcValueKind::Binary(value)
683            } else {
684                OdbcValueKind::Null
685            }
686        }
687        _ => {
688            let mut value = Vec::new();
689            if row
690                .get_wide_text(column_number, &mut value)
691                .map_err(crate::OdbcError::from)?
692            {
693                OdbcValueKind::Text(String::from_utf16_lossy(&value))
694            } else {
695                OdbcValueKind::Null
696            }
697        }
698    };
699
700    Ok(OdbcValue::new(kind))
701}
702
703fn fetch_nullable<T, F>(
704    row: &mut odbc_api::CursorRow<'_>,
705    column_number: u16,
706    map: F,
707) -> std::result::Result<OdbcValueKind, sqlx_core::Error>
708where
709    T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
710    Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
711    F: FnOnce(T) -> OdbcValueKind,
712{
713    let mut value = Nullable::<T>::null();
714    row.get_data(column_number, &mut value)
715        .map_err(crate::OdbcError::from)?;
716    Ok(value.into_opt().map(map).unwrap_or(OdbcValueKind::Null))
717}
718
#[cfg(test)]
mod tests {
    use super::*;

    /// Ping statement expected for Db2-family drivers.
    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";

    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            assert!(
                matches!(
                    map_buffer_desc(data_type, 64),
                    BufferDesc::I64 { nullable: true }
                ),
                "{data_type:?}"
            );
        }
    }

    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        let text = map_buffer_desc(DataType::Varchar { length: None }, 32);
        assert_eq!(text, BufferDesc::Text { max_str_len: 32 });

        let binary = map_buffer_desc(DataType::Varbinary { length: None }, 16);
        assert_eq!(binary, BufferDesc::Binary { max_bytes: 16 });
    }

    #[test]
    fn ping_query_uses_db2_dummy_table_for_db2_drivers() {
        for dbms_name in [
            "DB2",
            "DB2 UDB for AS/400",
            "IBM DB2 for i",
            "iSeries",
            "IBM i",
        ] {
            assert_eq!(DB2_PING, ping_query_for_dbms_name(dbms_name), "{dbms_name}");
        }
    }

    #[test]
    fn ping_query_keeps_select_one_for_non_db2_drivers() {
        for dbms_name in ["DuckDB", "Microsoft SQL Server", "PostgreSQL"] {
            assert_eq!("SELECT 1", ping_query_for_dbms_name(dbms_name), "{dbms_name}");
        }
    }
}