1use crate::{
2 OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcParameterCollection,
3 OdbcQueryResult, OdbcRow, OdbcStatement, OdbcTypeInfo, OdbcValue, OdbcValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::executor::{Execute, Executor};
12use sqlx_core::transaction::Transaction;
13use sqlx_core::Either;
14use std::future::Future;
15
16pub struct OdbcConnection {
21 conn: odbc_api::Connection<'static>,
22 buffer_settings: OdbcBufferSettings,
23 transaction_depth: usize,
24}
25
26impl std::fmt::Debug for OdbcConnection {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("OdbcConnection").finish_non_exhaustive()
29 }
30}
31
32impl OdbcConnection {
33 pub fn connect_blocking(options: &OdbcConnectOptions) -> Result<Self> {
35 let env = odbc_api::environment()
36 .map_err(|error| crate::OdbcError::Configuration(error.to_string()))?;
37
38 let conn =
39 env.connect_with_connection_string(options.connection_string(), Default::default())?;
40
41 Ok(Self {
42 conn,
43 buffer_settings: options.buffer_settings,
44 transaction_depth: 0,
45 })
46 }
47
48 pub fn ping_blocking(&mut self) -> Result<()> {
50 let query = self
51 .conn
52 .database_management_system_name()
53 .map(|name| ping_query_for_dbms_name(&name))
54 .unwrap_or("SELECT 1");
55 self.conn.execute(query, (), None)?;
56 Ok(())
57 }
58
59 pub fn dbms_name(&self) -> Result<String> {
61 Ok(self.conn.database_management_system_name()?)
62 }
63
64 pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
65 if self.transaction_depth > 0 {
66 return Err(sqlx_core::Error::InvalidSavePointStatement);
67 }
68
69 self.conn
70 .set_autocommit(false)
71 .map_err(crate::OdbcError::from)?;
72 self.transaction_depth = 1;
73 Ok(())
74 }
75
76 pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
77 if self.transaction_depth == 0 {
78 return Ok(());
79 }
80
81 self.conn.commit().map_err(crate::OdbcError::from)?;
82 self.conn
83 .set_autocommit(true)
84 .map_err(crate::OdbcError::from)?;
85 self.transaction_depth = 0;
86 Ok(())
87 }
88
89 pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
90 if self.transaction_depth == 0 {
91 return Ok(());
92 }
93
94 self.conn.rollback().map_err(crate::OdbcError::from)?;
95 self.conn
96 .set_autocommit(true)
97 .map_err(crate::OdbcError::from)?;
98 self.transaction_depth = 0;
99 Ok(())
100 }
101
102 pub(crate) fn start_rollback(&mut self) {
103 if self.transaction_depth == 0 {
104 return;
105 }
106
107 if self.conn.rollback().is_ok() {
108 let _ = self.conn.set_autocommit(true);
109 self.transaction_depth = 0;
110 }
111 }
112
113 pub(crate) const fn transaction_depth(&self) -> usize {
114 self.transaction_depth
115 }
116
117 pub fn prepare_blocking(
119 &mut self,
120 sql: sqlx_core::sql_str::SqlStr,
121 ) -> std::result::Result<OdbcStatement, sqlx_core::Error> {
122 let mut prepared = self
123 .conn
124 .prepare(sql.as_str())
125 .map_err(crate::OdbcError::from)?;
126 let parameters = prepared.num_params().map_err(crate::OdbcError::from)?;
127 let columns = collect_prepared_columns(&mut prepared, parameters)?;
128
129 Ok(OdbcStatement::new(sql, columns, usize::from(parameters)))
130 }
131
132 pub(crate) fn run_blocking_sql(
133 &mut self,
134 sql: &str,
135 arguments: Option<&OdbcArguments>,
136 ) -> std::result::Result<OdbcExecution, sqlx_core::Error> {
137 let mut statement = self.conn.preallocate().map_err(crate::OdbcError::from)?;
138 let parameters = odbc_parameters(arguments);
139
140 if let Some(cursor) = statement
141 .execute(sql, parameters.as_slice())
142 .map_err(crate::OdbcError::from)?
143 {
144 return collect_rows(cursor, self.buffer_settings).map(OdbcExecution::Rows);
145 }
146
147 let rows_affected = statement
148 .row_count()
149 .map_err(crate::OdbcError::from)?
150 .unwrap_or(0);
151
152 let rows_affected = rows_affected.try_into().map_err(|_| {
153 sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned())
154 })?;
155
156 Ok(OdbcExecution::Done(OdbcQueryResult::new(rows_affected)))
157 }
158}
159
160impl sqlx_core::connection::Connection for OdbcConnection {
161 type Database = crate::Odbc;
162 type Options = OdbcConnectOptions;
163
164 async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
165 drop(self);
166 Ok(())
167 }
168
169 async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
170 drop(self);
171 Ok(())
172 }
173
174 async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
175 self.ping_blocking().map_err(Into::into)
176 }
177
178 fn begin(
179 &mut self,
180 ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
181 + Send
182 + '_ {
183 Transaction::begin(self, None)
184 }
185
186 fn shrink_buffers(&mut self) {}
187
188 async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
189 Ok(())
190 }
191
192 fn should_flush(&self) -> bool {
193 false
194 }
195}
196
197impl<'c> Executor<'c> for &'c mut OdbcConnection {
198 type Database = crate::Odbc;
199
200 fn fetch_many<'e, 'q, E>(
201 self,
202 mut query: E,
203 ) -> BoxStream<'e, std::result::Result<Either<OdbcQueryResult, OdbcRow>, sqlx_core::Error>>
204 where
205 'c: 'e,
206 E: Execute<'q, Self::Database>,
207 'q: 'e,
208 E: 'q,
209 {
210 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
211 let sql = query.sql();
212
213 stream::once(async move {
214 let arguments = arguments?;
215 self.run_blocking_sql(sql.as_str(), arguments.as_ref())
216 })
217 .map(|result| match result {
218 Ok(OdbcExecution::Done(result)) => {
219 stream::once(future::ready(Ok(Either::Left(result)))).boxed()
220 }
221 Ok(OdbcExecution::Rows(rows)) => stream::iter(
222 rows.into_iter()
223 .map(|row| Ok(Either::Right(row)))
224 .chain(std::iter::once(Ok(Either::Left(OdbcQueryResult::new(0))))),
225 )
226 .boxed(),
227 Err(error) => stream::once(future::ready(Err(error))).boxed(),
228 })
229 .flatten()
230 .boxed()
231 }
232
233 fn fetch_optional<'e, 'q, E>(
234 self,
235 mut query: E,
236 ) -> BoxFuture<'e, std::result::Result<Option<OdbcRow>, sqlx_core::Error>>
237 where
238 'c: 'e,
239 E: Execute<'q, Self::Database>,
240 'q: 'e,
241 E: 'q,
242 {
243 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
244 let sql = query.sql();
245
246 Box::pin(async move {
247 let arguments = arguments?;
248
249 match self.run_blocking_sql(sql.as_str(), arguments.as_ref())? {
250 OdbcExecution::Rows(rows) => Ok(rows.into_iter().next()),
251 OdbcExecution::Done(_) => Ok(None),
252 }
253 })
254 }
255
256 fn prepare_with<'e>(
257 self,
258 sql: sqlx_core::sql_str::SqlStr,
259 _parameters: &[crate::OdbcTypeInfo],
260 ) -> BoxFuture<'e, std::result::Result<OdbcStatement, sqlx_core::Error>>
261 where
262 'c: 'e,
263 {
264 Box::pin(async move { self.prepare_blocking(sql) })
265 }
266}
267
268pub(crate) enum OdbcExecution {
269 Done(OdbcQueryResult),
270 Rows(Vec<OdbcRow>),
271}
272
273fn odbc_parameters(arguments: Option<&OdbcArguments>) -> OdbcParameterCollection {
274 arguments
275 .map(OdbcArguments::to_odbc_parameter_collection)
276 .unwrap_or_default()
277}
278
/// Chooses a lightweight liveness-check query for the DBMS name reported by the driver.
///
/// IBM Db2 family products (Db2, Db2 for i / iSeries / AS/400) are sent a query against
/// the `SYSIBM.SYSDUMMY1` catalog table; every other DBMS gets `SELECT 1`.
///
/// Matching is ASCII case-insensitive. The short marker "IBM i" is matched as two whole
/// words, so unrelated IBM products whose next word merely starts with an `I` (for example
/// "IBM Informix" or "IBM IMS") are not misclassified as Db2 for i.
fn ping_query_for_dbms_name(dbms_name: &str) -> &'static str {
    const DB2_QUERY: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";
    const DEFAULT_QUERY: &str = "SELECT 1";
    // Markers distinctive enough to be matched anywhere in the name. "I5/OS" keeps the
    // older "IBM i5/OS" spelling recognised now that "IBM I" needs a word boundary.
    const DB2_SUBSTRINGS: [&str; 5] = ["DB2", "DB/2", "ISERIES", "AS/400", "I5/OS"];

    let upper = dbms_name.to_ascii_uppercase();
    if DB2_SUBSTRINGS.iter().any(|marker| upper.contains(marker)) {
        return DB2_QUERY;
    }

    // "IBM i" must appear as the two separate words `IBM` and `I`.
    let words: Vec<&str> = upper
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|word| !word.is_empty())
        .collect();
    if words.windows(2).any(|pair| matches!(pair, ["IBM", "I"])) {
        DB2_QUERY
    } else {
        DEFAULT_QUERY
    }
}
293
294fn collect_columns(
295 cursor: &mut impl ResultSetMetadata,
296) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
297 let count = cursor.num_result_cols().map_err(crate::OdbcError::from)?;
298 let count = usize::try_from(count).map_err(|_| {
299 sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
300 })?;
301
302 let mut columns = Vec::with_capacity(count);
303 for ordinal in 0..count {
304 let column_number = u16::try_from(ordinal + 1).map_err(|_| {
305 sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
306 })?;
307
308 let mut description = odbc_api::ColumnDescription::default();
309 cursor
310 .describe_col(column_number, &mut description)
311 .map_err(crate::OdbcError::from)?;
312 let name = description
313 .name_to_string()
314 .unwrap_or_else(|_| format!("col{ordinal}"));
315
316 columns.push(OdbcColumn::new(
317 ordinal,
318 name,
319 OdbcTypeInfo::new(description.data_type),
320 ));
321 }
322
323 Ok(columns)
324}
325
326fn collect_prepared_columns(
327 prepared: &mut impl PreparedStatementMetadata,
328 parameter_count: u16,
329) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
330 match collect_columns(prepared) {
331 Ok(columns) => Ok(columns),
332 Err(error) if parameter_count > 0 => {
333 validate_parameter_metadata(prepared, parameter_count)?;
334 log::debug!("ODBC driver deferred result-column metadata until execution: {error}");
335 Ok(Vec::new())
336 }
337 Err(error) => Err(error),
338 }
339}
340
341trait PreparedStatementMetadata: ResultSetMetadata {
342 fn describe_prepared_parameter(
343 &mut self,
344 index: u16,
345 ) -> std::result::Result<(), odbc_api::Error>;
346}
347
348impl<S> PreparedStatementMetadata for odbc_api::Prepared<S>
349where
350 S: odbc_api::handles::AsStatementRef,
351{
352 fn describe_prepared_parameter(
353 &mut self,
354 index: u16,
355 ) -> std::result::Result<(), odbc_api::Error> {
356 self.describe_param(index).map(|_| ())
357 }
358}
359
360fn validate_parameter_metadata(
361 prepared: &mut impl PreparedStatementMetadata,
362 parameter_count: u16,
363) -> std::result::Result<(), sqlx_core::Error> {
364 for index in 1..=parameter_count {
365 prepared
366 .describe_prepared_parameter(index)
367 .map_err(crate::OdbcError::from)?;
368 }
369
370 Ok(())
371}
372
373fn collect_rows<C>(
374 cursor: C,
375 settings: OdbcBufferSettings,
376) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
377where
378 C: Cursor + ResultSetMetadata,
379{
380 if let Some(max_column_size) = settings.max_column_size {
381 collect_rows_buffered(cursor, settings.batch_size, max_column_size)
382 } else {
383 collect_rows_unbuffered(cursor)
384 }
385}
386
387#[derive(Debug)]
388struct ColumnBinding {
389 column: OdbcColumn,
390 buffer_desc: BufferDesc,
391}
392
393fn collect_rows_buffered<C>(
394 cursor: C,
395 batch_size: usize,
396 max_column_size: usize,
397) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
398where
399 C: Cursor + ResultSetMetadata,
400{
401 let mut cursor = cursor;
402 let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
403 let buffer_descriptions = bindings
404 .iter()
405 .map(|binding| binding.buffer_desc)
406 .collect::<Vec<_>>();
407 let mut row_set_cursor = cursor
408 .bind_buffer(ColumnarDynBuffer::from_descs(
409 batch_size,
410 buffer_descriptions,
411 ))
412 .map_err(|error| {
413 crate::error::database_error_with_context(
414 error,
415 format!(
416 "ODBC buffered fetching could not be enabled with batch_size={batch_size}; \
417 this driver may reject the row-array or row-binding statement attributes \
418 used for column-wise buffered fetching, so use \
419 OdbcConnectOptions::max_column_size(None) to fetch rows unbuffered"
420 ),
421 )
422 })?;
423 let columns = bindings
424 .iter()
425 .map(|binding| binding.column.clone())
426 .collect::<Vec<_>>();
427 let mut rows = Vec::new();
428
429 while let Some(batch) = row_set_cursor.fetch().map_err(crate::OdbcError::from)? {
430 let column_values = bindings
431 .iter()
432 .enumerate()
433 .map(|(index, binding)| {
434 buffered_column_values(batch.column(index), binding.buffer_desc)
435 })
436 .collect::<std::result::Result<Vec<_>, _>>()?;
437
438 for row_index in 0..batch.num_rows() {
439 let values = column_values
440 .iter()
441 .map(|values| OdbcValue::new(values[row_index].clone()))
442 .collect::<Vec<_>>();
443 rows.push(OdbcRow::new(columns.clone(), values));
444 }
445 }
446
447 Ok(rows)
448}
449
450fn build_buffer_bindings(
451 cursor: &mut impl ResultSetMetadata,
452 max_column_size: usize,
453) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
454 collect_columns(cursor).map(|columns| {
455 columns
456 .into_iter()
457 .map(|column| ColumnBinding {
458 buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size),
459 column,
460 })
461 .collect()
462 })
463}
464
465fn map_buffer_desc(data_type: DataType, max_column_size: usize) -> BufferDesc {
466 match data_type {
467 DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
468 BufferDesc::I64 { nullable: true }
469 }
470 DataType::Real => BufferDesc::F32 { nullable: true },
471 DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable: true },
472 DataType::Bit => BufferDesc::Bit { nullable: true },
473 DataType::Date => BufferDesc::Date { nullable: true },
474 DataType::Time { .. } => BufferDesc::Time { nullable: true },
475 DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable: true },
476 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
477 BufferDesc::Binary {
478 max_bytes: max_column_size,
479 }
480 }
481 DataType::Char { .. }
482 | DataType::WChar { .. }
483 | DataType::Varchar { .. }
484 | DataType::WVarchar { .. }
485 | DataType::LongVarchar { .. }
486 | DataType::WLongVarchar { .. }
487 | DataType::Other { .. }
488 | DataType::Unknown
489 | DataType::Decimal { .. }
490 | DataType::Numeric { .. } => BufferDesc::Text {
491 max_str_len: max_column_size,
492 },
493 }
494}
495
496fn buffered_column_values(
497 slice: AnyColumnBufferSlice<'_>,
498 desc: BufferDesc,
499) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error> {
500 Ok(match desc {
501 BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
502 OdbcValueKind::TinyInt(value)
503 })?,
504 BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
505 OdbcValueKind::SmallInt(value)
506 })?,
507 BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
508 OdbcValueKind::Integer(value)
509 })?,
510 BufferDesc::I64 { nullable } => {
511 buffered_numeric(&slice, desc, nullable, OdbcValueKind::BigInt)?
512 }
513 BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
514 OdbcValueKind::BigInt(i64::from(value))
515 })?,
516 BufferDesc::F32 { nullable } => {
517 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Real)?
518 }
519 BufferDesc::F64 { nullable } => {
520 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Double)?
521 }
522 BufferDesc::Bit { nullable } => {
523 buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
524 OdbcValueKind::Bit(value.as_bool())
525 })?
526 }
527 BufferDesc::Date { nullable } => {
528 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Date)?
529 }
530 BufferDesc::Time { nullable } => {
531 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Time)?
532 }
533 BufferDesc::Timestamp { nullable } => {
534 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Timestamp)?
535 }
536 BufferDesc::Text { .. } => {
537 let text = expect_buffer_slice(slice.as_text(), desc)?;
538 text.iter()
539 .map(|value| {
540 value
541 .map(|bytes| {
542 OdbcValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
543 })
544 .unwrap_or(OdbcValueKind::Null)
545 })
546 .collect()
547 }
548 BufferDesc::WText { .. } => {
549 let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
550 text.iter()
551 .map(|value| {
552 value
553 .map(|chars| OdbcValueKind::Text(String::from_utf16_lossy(chars.into())))
554 .unwrap_or(OdbcValueKind::Null)
555 })
556 .collect()
557 }
558 BufferDesc::Binary { .. } => {
559 let binary = expect_buffer_slice(slice.as_binary(), desc)?;
560 binary
561 .iter()
562 .map(|value| {
563 value
564 .map(|bytes| OdbcValueKind::Binary(bytes.to_vec()))
565 .unwrap_or(OdbcValueKind::Null)
566 })
567 .collect()
568 }
569 BufferDesc::Numeric => {
570 return Err(sqlx_core::Error::Protocol(format!(
571 "unsupported ODBC buffer descriptor: {desc:?}"
572 )))
573 }
574 })
575}
576
577fn buffered_numeric<T, F>(
578 slice: &AnyColumnBufferSlice<'_>,
579 desc: BufferDesc,
580 nullable: bool,
581 map: F,
582) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error>
583where
584 T: Copy + odbc_api::Pod,
585 F: FnMut(T) -> OdbcValueKind,
586{
587 if nullable {
588 Ok(buffered_nullable_numeric(
589 expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
590 map,
591 ))
592 } else {
593 Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
594 .iter()
595 .copied()
596 .map(map)
597 .collect())
598 }
599}
600
601fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<OdbcValueKind>
602where
603 T: Copy,
604 F: FnMut(T) -> OdbcValueKind,
605{
606 slice
607 .map(|value| value.copied().map(&mut map).unwrap_or(OdbcValueKind::Null))
608 .collect()
609}
610
611fn expect_buffer_slice<T>(
612 slice: Option<T>,
613 desc: BufferDesc,
614) -> std::result::Result<T, sqlx_core::Error> {
615 slice.ok_or_else(|| {
616 sqlx_core::Error::Protocol(format!(
617 "ODBC column buffer {desc:?} did not match fetched slice"
618 ))
619 })
620}
621
622fn collect_rows_unbuffered<C>(mut cursor: C) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
623where
624 C: Cursor + ResultSetMetadata,
625{
626 let columns = collect_columns(&mut cursor)?;
627 let mut rows = Vec::new();
628
629 while let Some(mut cursor_row) = cursor.next_row().map_err(crate::OdbcError::from)? {
630 let mut values = Vec::with_capacity(columns.len());
631
632 for column in &columns {
633 let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
634 .map_err(|_| {
635 sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
636 })?;
637 values.push(fetch_value(
638 &mut cursor_row,
639 column_number,
640 column.type_info().data_type(),
641 )?);
642 }
643
644 rows.push(OdbcRow::new(columns.clone(), values));
645 }
646
647 Ok(rows)
648}
649
650fn fetch_value(
651 row: &mut odbc_api::CursorRow<'_>,
652 column_number: u16,
653 data_type: DataType,
654) -> std::result::Result<OdbcValue, sqlx_core::Error> {
655 let kind = match data_type {
656 DataType::Bit => {
657 let mut value = Nullable::<odbc_api::Bit>::null();
658 row.get_data(column_number, &mut value)
659 .map_err(crate::OdbcError::from)?;
660 value
661 .into_opt()
662 .map(|value| OdbcValueKind::Bit(value.as_bool()))
663 .unwrap_or(OdbcValueKind::Null)
664 }
665 DataType::TinyInt => fetch_nullable(row, column_number, OdbcValueKind::TinyInt)?,
666 DataType::SmallInt => fetch_nullable(row, column_number, OdbcValueKind::SmallInt)?,
667 DataType::Integer => fetch_nullable(row, column_number, OdbcValueKind::Integer)?,
668 DataType::BigInt => fetch_nullable(row, column_number, OdbcValueKind::BigInt)?,
669 DataType::Real => fetch_nullable(row, column_number, OdbcValueKind::Real)?,
670 DataType::Float { .. } | DataType::Double => {
671 fetch_nullable(row, column_number, OdbcValueKind::Double)?
672 }
673 DataType::Date => fetch_nullable(row, column_number, OdbcValueKind::Date)?,
674 DataType::Time { .. } => fetch_nullable(row, column_number, OdbcValueKind::Time)?,
675 DataType::Timestamp { .. } => fetch_nullable(row, column_number, OdbcValueKind::Timestamp)?,
676 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
677 let mut value = Vec::new();
678 if row
679 .get_binary(column_number, &mut value)
680 .map_err(crate::OdbcError::from)?
681 {
682 OdbcValueKind::Binary(value)
683 } else {
684 OdbcValueKind::Null
685 }
686 }
687 _ => {
688 let mut value = Vec::new();
689 if row
690 .get_wide_text(column_number, &mut value)
691 .map_err(crate::OdbcError::from)?
692 {
693 OdbcValueKind::Text(String::from_utf16_lossy(&value))
694 } else {
695 OdbcValueKind::Null
696 }
697 }
698 };
699
700 Ok(OdbcValue::new(kind))
701}
702
703fn fetch_nullable<T, F>(
704 row: &mut odbc_api::CursorRow<'_>,
705 column_number: u16,
706 map: F,
707) -> std::result::Result<OdbcValueKind, sqlx_core::Error>
708where
709 T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
710 Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
711 F: FnOnce(T) -> OdbcValueKind,
712{
713 let mut value = Nullable::<T>::null();
714 row.get_data(column_number, &mut value)
715 .map_err(crate::OdbcError::from)?;
716 Ok(value.into_opt().map(map).unwrap_or(OdbcValueKind::Null))
717}
718
#[cfg(test)]
mod tests {
    use super::*;

    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";

    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            assert!(
                matches!(
                    map_buffer_desc(data_type, 64),
                    BufferDesc::I64 { nullable: true }
                ),
                "{data_type:?} should use a nullable i64 buffer"
            );
        }
    }

    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        let text = map_buffer_desc(DataType::Varchar { length: None }, 32);
        assert_eq!(text, BufferDesc::Text { max_str_len: 32 });

        let binary = map_buffer_desc(DataType::Varbinary { length: None }, 16);
        assert_eq!(binary, BufferDesc::Binary { max_bytes: 16 });
    }

    #[test]
    fn ping_query_uses_db2_dummy_table_for_db2_drivers() {
        let db2_names = ["DB2", "DB2 UDB for AS/400", "IBM DB2 for i", "iSeries", "IBM i"];
        for name in db2_names {
            assert_eq!(DB2_PING, ping_query_for_dbms_name(name), "dbms name: {name}");
        }
    }

    #[test]
    fn ping_query_keeps_select_one_for_non_db2_drivers() {
        for name in ["DuckDB", "Microsoft SQL Server", "PostgreSQL"] {
            assert_eq!("SELECT 1", ping_query_for_dbms_name(name), "dbms name: {name}");
        }
    }
}