1use crate::{OdbcConnection, Result};
2use log::LevelFilter;
3use std::fmt::{self, Debug, Formatter};
4use std::str::FromStr;
5use std::time::Duration;
6use url::Url;
7
/// Buffer sizing knobs for an ODBC connection.
///
/// Values are plain data; the setters on `OdbcConnectOptions` are what reject
/// zero values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OdbcBufferSettings {
    /// Batch size (presumably rows per fetch — confirm against the fetch code).
    /// `OdbcConnectOptions` setters reject `0`.
    pub batch_size: usize,
    /// Optional upper bound on a column's buffer size; `None` means no limit is
    /// configured. `OdbcConnectOptions` setters reject `Some(0)`.
    pub max_column_size: Option<usize>,
}

impl Default for OdbcBufferSettings {
    /// A batch size of 64 with no column size limit.
    fn default() -> Self {
        const DEFAULT_BATCH_SIZE: usize = 64;

        Self {
            max_column_size: None,
            batch_size: DEFAULT_BATCH_SIZE,
        }
    }
}
29
30#[derive(Clone)]
32pub struct OdbcConnectOptions {
33 pub(crate) conn_str: String,
34 pub(crate) buffer_settings: OdbcBufferSettings,
35 pub(crate) log_statements: LevelFilter,
36 pub(crate) log_slow_statements: LevelFilter,
37 pub(crate) log_slow_statement_duration: Duration,
38}
39
40impl OdbcConnectOptions {
41 pub fn connection_string(&self) -> &str {
43 &self.conn_str
44 }
45
46 pub fn buffer_settings(&mut self, settings: OdbcBufferSettings) -> &mut Self {
48 assert!(settings.batch_size > 0, "batch_size must be greater than 0");
49 if let Some(size) = settings.max_column_size {
50 assert!(size > 0, "max_column_size must be greater than 0");
51 }
52
53 self.buffer_settings = settings;
54 self
55 }
56
57 pub fn buffer_settings_ref(&self) -> &OdbcBufferSettings {
59 &self.buffer_settings
60 }
61
62 pub fn batch_size(&mut self, batch_size: usize) -> &mut Self {
64 assert!(batch_size > 0, "batch_size must be greater than 0");
65 self.buffer_settings.batch_size = batch_size;
66 self
67 }
68
69 pub fn max_column_size(&mut self, max_column_size: Option<usize>) -> &mut Self {
71 if let Some(size) = max_column_size {
72 assert!(size > 0, "max_column_size must be greater than 0");
73 }
74
75 self.buffer_settings.max_column_size = max_column_size;
76 self
77 }
78
79 pub fn log_statements(&mut self, level: LevelFilter) -> &mut Self {
81 self.log_statements = level;
82 self
83 }
84
85 pub fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) -> &mut Self {
87 self.log_slow_statements = level;
88 self.log_slow_statement_duration = duration;
89 self
90 }
91
92 pub fn connect_blocking(&self) -> Result<OdbcConnection> {
96 OdbcConnection::connect_blocking(self)
97 }
98}
99
100impl Debug for OdbcConnectOptions {
101 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102 f.debug_struct("OdbcConnectOptions")
103 .field("conn_str", &"<redacted>")
104 .field("buffer_settings", &self.buffer_settings)
105 .field("log_statements", &self.log_statements)
106 .field("log_slow_statements", &self.log_slow_statements)
107 .field(
108 "log_slow_statement_duration",
109 &self.log_slow_statement_duration,
110 )
111 .finish()
112 }
113}
114
115impl FromStr for OdbcConnectOptions {
116 type Err = sqlx_core::Error;
117
118 fn from_str(input: &str) -> std::result::Result<Self, Self::Err> {
119 let mut trimmed = input.trim();
120
121 if let Some(rest) = trimmed.strip_prefix("odbc:") {
122 trimmed = rest;
123 }
124
125 let conn_str = if trimmed.contains('=') {
126 trimmed.to_owned()
127 } else {
128 format!("DSN={trimmed}")
129 };
130
131 Ok(Self {
132 conn_str,
133 buffer_settings: OdbcBufferSettings::default(),
134 log_statements: LevelFilter::Debug,
135 log_slow_statements: LevelFilter::Warn,
136 log_slow_statement_duration: Duration::from_secs(1),
137 })
138 }
139}
140
141impl sqlx_core::connection::ConnectOptions for OdbcConnectOptions {
142 type Connection = OdbcConnection;
143
144 fn from_url(url: &Url) -> std::result::Result<Self, sqlx_core::Error> {
145 Self::from_str(url.as_str())
146 }
147
148 async fn connect(&self) -> std::result::Result<Self::Connection, sqlx_core::Error> {
149 self.connect_blocking().map_err(Into::into)
150 }
151
152 fn log_statements(mut self, level: LevelFilter) -> Self {
153 self.log_statements = level;
154 self
155 }
156
157 fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
158 self.log_slow_statements = level;
159 self.log_slow_statement_duration = duration;
160 self
161 }
162}
163
#[cfg(test)]
mod tests {
    //! Unit tests for connection-string parsing, buffer-setting validation and
    //! Debug redaction.

    use super::*;

    #[test]
    fn parses_bare_dsn_as_dsn_connection_string() {
        let options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn preserves_standard_connection_strings() {
        let input = "Driver={ODBC Driver 17 for SQL Server};Server=localhost;Database=test";
        let options = OdbcConnectOptions::from_str(input).unwrap();
        assert_eq!(options.connection_string(), input);
    }

    #[test]
    fn strips_legacy_odbc_prefix() {
        let options = OdbcConnectOptions::from_str("odbc:DSN=Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn strips_legacy_odbc_prefix_from_bare_dsn() {
        let options = OdbcConnectOptions::from_str("odbc:Warehouse").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn trims_surrounding_whitespace() {
        let options = OdbcConnectOptions::from_str("  Warehouse \n").unwrap();
        assert_eq!(options.connection_string(), "DSN=Warehouse");
    }

    #[test]
    fn updates_buffer_settings_incrementally() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.batch_size(128).max_column_size(Some(2048));

        assert_eq!(
            *options.buffer_settings_ref(),
            OdbcBufferSettings {
                batch_size: 128,
                max_column_size: Some(2048)
            }
        );
    }

    #[test]
    fn replaces_buffer_settings_wholesale() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        let settings = OdbcBufferSettings {
            batch_size: 8,
            max_column_size: None,
        };
        options.max_column_size(Some(16)).buffer_settings(settings);

        assert_eq!(*options.buffer_settings_ref(), settings);
    }

    #[test]
    #[should_panic(expected = "batch_size must be greater than 0")]
    fn rejects_zero_batch_size() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.batch_size(0);
    }

    #[test]
    #[should_panic(expected = "max_column_size must be greater than 0")]
    fn rejects_zero_max_column_size() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.max_column_size(Some(0));
    }

    #[test]
    #[should_panic(expected = "batch_size must be greater than 0")]
    fn rejects_zero_batch_size_in_wholesale_settings() {
        let mut options = OdbcConnectOptions::from_str("Warehouse").unwrap();
        options.buffer_settings(OdbcBufferSettings {
            batch_size: 0,
            max_column_size: None,
        });
    }

    #[test]
    fn debug_output_redacts_connection_string() {
        let options =
            OdbcConnectOptions::from_str("Driver={X};Uid=admin;Pwd=hunter2").unwrap();
        let rendered = format!("{options:?}");

        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains("hunter2"));
    }
}