Skip to main content

sqlx_odbc/
options.rs

1use crate::{OdbcConnection, Result};
2use log::LevelFilter;
3use std::fmt::{self, Debug, Formatter};
4use std::str::FromStr;
5use std::time::Duration;
6use url::Url;
7
/// Fetch-buffer settings used by the ODBC driver.
///
/// `max_column_size` selects between two fetch modes:
///
/// * `Some(size)` enables buffered fetching. Long text or binary fields can be truncated to the
///   configured size.
/// * `None` keeps fetching unbuffered, so variable sized values are not truncated by this
///   crate's buffer allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OdbcBufferSettings {
    /// Number of rows fetched in each batch.
    pub batch_size: usize,
    /// Maximum text or binary column size in buffered mode, or `None` for unbuffered mode.
    pub max_column_size: Option<usize>,
}
20
21impl Default for OdbcBufferSettings {
22    fn default() -> Self {
23        Self {
24            batch_size: 64,
25            max_column_size: None,
26        }
27    }
28}
29
30/// Connection options for an ODBC data source.
31#[derive(Clone)]
32pub struct OdbcConnectOptions {
33    pub(crate) conn_str: String,
34    pub(crate) buffer_settings: OdbcBufferSettings,
35    pub(crate) log_statements: LevelFilter,
36    pub(crate) log_slow_statements: LevelFilter,
37    pub(crate) log_slow_statement_duration: Duration,
38}
39
40impl OdbcConnectOptions {
41    /// Returns the normalized ODBC connection string.
42    pub fn connection_string(&self) -> &str {
43        &self.conn_str
44    }
45
46    /// Sets the buffer configuration for this connection.
47    pub fn buffer_settings(&mut self, settings: OdbcBufferSettings) -> &mut Self {
48        assert!(settings.batch_size > 0, "batch_size must be greater than 0");
49        if let Some(size) = settings.max_column_size {
50            assert!(size > 0, "max_column_size must be greater than 0");
51        }
52
53        self.buffer_settings = settings;
54        self
55    }
56
57    /// Returns the current buffer settings.
58    pub fn buffer_settings_ref(&self) -> &OdbcBufferSettings {
59        &self.buffer_settings
60    }
61
62    /// Sets the number of rows fetched in each batch.
63    pub fn batch_size(&mut self, batch_size: usize) -> &mut Self {
64        assert!(batch_size > 0, "batch_size must be greater than 0");
65        self.buffer_settings.batch_size = batch_size;
66        self
67    }
68
69    /// Sets the maximum buffered column size, or `None` for unbuffered fetching.
70    pub fn max_column_size(&mut self, max_column_size: Option<usize>) -> &mut Self {
71        if let Some(size) = max_column_size {
72            assert!(size > 0, "max_column_size must be greater than 0");
73        }
74
75        self.buffer_settings.max_column_size = max_column_size;
76        self
77    }
78
79    /// Sets regular statement logging level.
80    pub fn log_statements(&mut self, level: LevelFilter) -> &mut Self {
81        self.log_statements = level;
82        self
83    }
84
85    /// Sets slow statement logging level and threshold.
86    pub fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) -> &mut Self {
87        self.log_slow_statements = level;
88        self.log_slow_statement_duration = duration;
89        self
90    }
91
92    /// Opens a blocking ODBC connection.
93    ///
94    /// The full SQLx async connection/executor API will be layered on top of this during the port.
95    pub fn connect_blocking(&self) -> Result<OdbcConnection> {
96        OdbcConnection::connect_blocking(self)
97    }
98}
99
100impl Debug for OdbcConnectOptions {
101    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102        f.debug_struct("OdbcConnectOptions")
103            .field("conn_str", &"<redacted>")
104            .field("buffer_settings", &self.buffer_settings)
105            .field("log_statements", &self.log_statements)
106            .field("log_slow_statements", &self.log_slow_statements)
107            .field(
108                "log_slow_statement_duration",
109                &self.log_slow_statement_duration,
110            )
111            .finish()
112    }
113}
114
115impl FromStr for OdbcConnectOptions {
116    type Err = sqlx_core::Error;
117
118    fn from_str(input: &str) -> std::result::Result<Self, Self::Err> {
119        let mut trimmed = input.trim();
120
121        if let Some(rest) = trimmed.strip_prefix("odbc:") {
122            trimmed = rest;
123        }
124
125        let conn_str = if trimmed.contains('=') {
126            trimmed.to_owned()
127        } else {
128            format!("DSN={trimmed}")
129        };
130
131        Ok(Self {
132            conn_str,
133            buffer_settings: OdbcBufferSettings::default(),
134            log_statements: LevelFilter::Debug,
135            log_slow_statements: LevelFilter::Warn,
136            log_slow_statement_duration: Duration::from_secs(1),
137        })
138    }
139}
140
141impl sqlx_core::connection::ConnectOptions for OdbcConnectOptions {
142    type Connection = OdbcConnection;
143
144    fn from_url(url: &Url) -> std::result::Result<Self, sqlx_core::Error> {
145        Self::from_str(url.as_str())
146    }
147
148    async fn connect(&self) -> std::result::Result<Self::Connection, sqlx_core::Error> {
149        self.connect_blocking().map_err(Into::into)
150    }
151
152    fn log_statements(mut self, level: LevelFilter) -> Self {
153        self.log_statements = level;
154        self
155    }
156
157    fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
158        self.log_slow_statements = level;
159        self.log_slow_statement_duration = duration;
160        self
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// A value without `=` is treated as a data source name.
    #[test]
    fn parses_bare_dsn_as_dsn_connection_string() {
        let options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    /// A `key=value` connection string is stored verbatim.
    #[test]
    fn preserves_standard_connection_strings() {
        let input = "Driver={ODBC Driver 17 for SQL Server};Server=localhost;Database=test";
        let options = OdbcConnectOptions::from_str(input).unwrap();
        assert_eq!(options.connection_string(), input);
    }

    /// The `odbc:` prefix is removed before the rest is interpreted.
    #[test]
    fn strips_legacy_odbc_prefix() {
        let options = OdbcConnectOptions::from_str("odbc:DSN=Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    /// Leading and trailing whitespace does not end up in the connection string.
    #[test]
    fn trims_surrounding_whitespace() {
        let options = OdbcConnectOptions::from_str("  Warehouse \n").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    /// The individual setters can be chained and each updates only its own field.
    #[test]
    fn updates_buffer_settings_incrementally() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.batch_size(128).max_column_size(Some(2048));

        assert_eq!(
            *options.buffer_settings_ref(),
            OdbcBufferSettings {
                batch_size: 128,
                max_column_size: Some(2048)
            }
        );
    }

    /// The `Debug` output never contains the real connection string.
    #[test]
    fn debug_output_redacts_connection_string() {
        let options = OdbcConnectOptions::from_str("DSN=Warehouse;PWD=hunter2").unwrap();
        let rendered = format!("{options:?}");

        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("hunter2"));
    }
}