Skip to main content

tiberius/
client.rs

1mod auth;
2mod config;
3mod connection;
4
5#[cfg(all(windows, feature = "winauth"))]
6mod sspi;
7mod tls;
8#[cfg(any(
9    feature = "rustls",
10    feature = "native-tls",
11    feature = "vendored-openssl"
12))]
13mod tls_stream;
14
15pub use auth::*;
16pub use config::*;
17pub(crate) use connection::*;
18
19use crate::tds::stream::ReceivedToken;
20use crate::{
21    result::ExecuteResult,
22    tds::{
23        codec::{self, IteratorJoin},
24        stream::{QueryStream, TokenStream},
25    },
26    BulkLoadColumns, BulkLoadRequest, ColumnFlag, SqlReadBytes, ToSql,
27};
28use codec::{BatchRequest, ColumnData, PacketHeader, RpcParam, RpcProcId, TokenRpcRequest};
29use enumflags2::BitFlags;
30use futures_util::io::{AsyncRead, AsyncWrite};
31use futures_util::stream::TryStreamExt;
32use std::{borrow::Cow, fmt::Debug};
33
34/// `Client` is the main entry point to the SQL Server, providing query
35/// execution capabilities.
36///
37/// A `Client` is created using the [`Config`], defining the needed
38/// connection options and capabilities.
39///
40/// # Example
41///
42/// ```no_run
43/// # use tiberius::{Config, AuthMethod};
44/// use tokio_util::compat::TokioAsyncWriteCompatExt;
45///
46/// # #[tokio::main]
47/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
48/// let mut config = Config::new();
49///
50/// config.host("0.0.0.0");
51/// config.port(1433);
52/// config.authentication(AuthMethod::sql_server("SA", "<Mys3cureP4ssW0rD>"));
53///
54/// let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
55/// tcp.set_nodelay(true)?;
56/// // Client is ready to use.
57/// let client = tiberius::Client::connect(config, tcp.compat_write()).await?;
58/// # Ok(())
59/// # }
60/// ```
61///
62/// [`Config`]: struct.Config.html
63#[derive(Debug)]
64pub struct Client<S: AsyncRead + AsyncWrite + Unpin + Send> {
65    pub(crate) connection: Connection<S>,
66}
67
68impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
69    /// Uses an instance of [`Config`] to specify the connection
70    /// options required to connect to the database using an established
71    /// tcp connection
72    ///
73    /// [`Config`]: struct.Config.html
74    pub async fn connect(config: Config, tcp_stream: S) -> crate::Result<Client<S>> {
75        Ok(Client {
76            connection: Connection::connect(config, tcp_stream).await?,
77        })
78    }
79
80    /// Executes SQL statements in the SQL Server, returning the number rows
81    /// affected. Useful for `INSERT`, `UPDATE` and `DELETE` statements. The
82    /// `query` can define the parameter placement by annotating them with
83    /// `@PN`, where N is the index of the parameter, starting from `1`. If
84    /// executing multiple queries at a time, delimit them with `;` and refer to
85    /// [`ExecuteResult`] how to get results for the separate queries.
86    ///
87    /// For mapping of Rust types when writing, see the documentation for
88    /// [`ToSql`]. For reading data from the database, see the documentation for
89    /// [`FromSql`].
90    ///
91    /// This API is not quite suitable for dynamic query parameters. In these
92    /// cases using a [`Query`] object might be easier.
93    ///
94    /// # Example
95    ///
96    /// ```no_run
97    /// # use tiberius::Config;
98    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
99    /// # use std::env;
100    /// # #[tokio::main]
101    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
102    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
103    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
104    /// # );
105    /// # let config = Config::from_ado_string(&c_str)?;
106    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
107    /// # tcp.set_nodelay(true)?;
108    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
109    /// let results = client
110    ///     .execute(
111    ///         "INSERT INTO ##Test (id) VALUES (@P1), (@P2), (@P3)",
112    ///         &[&1i32, &2i32, &3i32],
113    ///     )
114    ///     .await?;
115    /// # Ok(())
116    /// # }
117    /// ```
118    ///
119    /// [`ExecuteResult`]: struct.ExecuteResult.html
120    /// [`ToSql`]: trait.ToSql.html
121    /// [`FromSql`]: trait.FromSql.html
122    /// [`Query`]: struct.Query.html
123    pub async fn execute<'a>(
124        &mut self,
125        query: impl Into<Cow<'a, str>>,
126        params: &[&dyn ToSql],
127    ) -> crate::Result<ExecuteResult> {
128        self.connection.flush_stream().await?;
129        let rpc_params = Self::rpc_params(query);
130
131        let params = params.iter().map(|s| s.to_sql());
132        self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
133            .await?;
134
135        ExecuteResult::new(&mut self.connection).await
136    }
137
138    /// Executes SQL statements in the SQL Server, returning resulting rows.
139    /// Useful for `SELECT` statements. The `query` can define the parameter
140    /// placement by annotating them with `@PN`, where N is the index of the
141    /// parameter, starting from `1`. If executing multiple queries at a time,
142    /// delimit them with `;` and refer to [`QueryStream`] on proper stream
143    /// handling.
144    ///
145    /// For mapping of Rust types when writing, see the documentation for
146    /// [`ToSql`]. For reading data from the database, see the documentation for
147    /// [`FromSql`].
148    ///
149    /// This API can be cumbersome for dynamic query parameters. In these cases,
150    /// if fighting too much with the compiler, using a [`Query`] object might be
151    /// easier.
152    ///
153    /// # Example
154    ///
155    /// ```
156    /// # use tiberius::Config;
157    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
158    /// # use std::env;
159    /// # #[tokio::main]
160    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
161    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
162    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
163    /// # );
164    /// # let config = Config::from_ado_string(&c_str)?;
165    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
166    /// # tcp.set_nodelay(true)?;
167    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
168    /// let stream = client
169    ///     .query(
170    ///         "SELECT @P1, @P2, @P3",
171    ///         &[&1i32, &2i32, &3i32],
172    ///     )
173    ///     .await?;
174    /// # Ok(())
175    /// # }
176    /// ```
177    ///
178    /// [`QueryStream`]: struct.QueryStream.html
179    /// [`Query`]: struct.Query.html
180    /// [`ToSql`]: trait.ToSql.html
181    /// [`FromSql`]: trait.FromSql.html
182    pub async fn query<'a, 'b>(
183        &'a mut self,
184        query: impl Into<Cow<'b, str>>,
185        params: &'b [&'b dyn ToSql],
186    ) -> crate::Result<QueryStream<'a>>
187    where
188        'a: 'b,
189    {
190        self.connection.flush_stream().await?;
191        let rpc_params = Self::rpc_params(query);
192
193        let params = params.iter().map(|p| p.to_sql());
194        self.rpc_perform_query(RpcProcId::ExecuteSQL, rpc_params, params)
195            .await?;
196
197        let ts = TokenStream::new(&mut self.connection);
198        let mut result = QueryStream::new(ts.try_unfold());
199        result.forward_to_metadata().await?;
200
201        Ok(result)
202    }
203
204    /// Execute multiple queries, delimited with `;` and return multiple result
205    /// sets; one for each query.
206    ///
207    /// # Example
208    ///
209    /// ```
210    /// # use tiberius::Config;
211    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
212    /// # use std::env;
213    /// # #[tokio::main]
214    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
215    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
216    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
217    /// # );
218    /// # let config = Config::from_ado_string(&c_str)?;
219    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
220    /// # tcp.set_nodelay(true)?;
221    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
222    /// let row = client.simple_query("SELECT 1 AS col").await?.into_row().await?.unwrap();
223    /// assert_eq!(Some(1i32), row.get("col"));
224    /// # Ok(())
225    /// # }
226    /// ```
227    ///
228    /// # Warning
229    ///
230    /// Do not use this with any user specified input. Please resort to prepared
231    /// statements using the [`query`] method.
232    ///
233    /// [`query`]: #method.query
234    pub async fn simple_query<'a, 'b>(
235        &'a mut self,
236        query: impl Into<Cow<'b, str>>,
237    ) -> crate::Result<QueryStream<'a>>
238    where
239        'a: 'b,
240    {
241        self.connection.flush_stream().await?;
242
243        let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
244
245        let id = self.connection.context_mut().next_packet_id();
246        self.connection.send(PacketHeader::batch(id), req).await?;
247
248        let ts = TokenStream::new(&mut self.connection);
249
250        let mut result = QueryStream::new(ts.try_unfold());
251        result.forward_to_metadata().await?;
252
253        Ok(result)
254    }
255
256    /// Execute a `BULK INSERT` statement, efficiantly storing a large number of
257    /// rows to a specified table. Note: make sure the input row follows the same
258    /// schema as the table, otherwise calling `send()` will return an error.
259    ///
260    /// # Example
261    ///
262    /// ```
263    /// # use tiberius::{Config, IntoRow};
264    /// # use tokio_util::compat::TokioAsyncWriteCompatExt;
265    /// # use std::env;
266    /// # #[tokio::main]
267    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
268    /// # let c_str = env::var("TIBERIUS_TEST_CONNECTION_STRING").unwrap_or(
269    /// #     "server=tcp:localhost,1433;integratedSecurity=true;TrustServerCertificate=true".to_owned(),
270    /// # );
271    /// # let config = Config::from_ado_string(&c_str)?;
272    /// # let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
273    /// # tcp.set_nodelay(true)?;
274    /// # let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
275    /// let create_table = r#"
276    ///     CREATE TABLE ##bulk_test (
277    ///         id INT IDENTITY PRIMARY KEY,
278    ///         val INT NOT NULL
279    ///     )
280    /// "#;
281    ///
282    /// client.simple_query(create_table).await?;
283    ///
284    /// // Start the bulk insert with the client.
285    /// let mut req = client.bulk_insert("##bulk_test").await?;
286    ///
287    /// for i in [0i32, 1i32, 2i32] {
288    ///     let row = (i).into_row();
289    ///
290    ///     // The request will handle flushing to the wire in an optimal way,
291    ///     // balancing between memory usage and IO performance.
292    ///     req.send(row).await?;
293    /// }
294    ///
295    /// // The request must be finalized.
296    /// let res = req.finalize().await?;
297    /// assert_eq!(3, res.total());
298    /// # Ok(())
299    /// # }
300    /// ```
301    pub async fn bulk_insert<'a>(
302        &'a mut self,
303        table: &str,
304    ) -> crate::Result<BulkLoadRequest<'a, S>> {
305        let columns = self.bulk_insert_columns(table).await?;
306        self.bulk_insert_with_columns(table, columns).await
307    }
308
309    /// Returns updateable target column metadata for a future bulk insert.
310    ///
311    /// This method only sends a metadata query. It does not start the
312    /// `INSERT BULK` protocol flow, so callers can validate the target table
313    /// and fail without needing to finalize an empty bulk-load request.
314    pub async fn bulk_insert_columns(
315        &mut self,
316        table: &str,
317    ) -> crate::Result<BulkLoadColumns<'static>> {
318        self.connection.flush_stream().await?;
319
320        let query = format!("SELECT TOP 0 * FROM {}", table);
321
322        let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
323
324        let id = self.connection.context_mut().next_packet_id();
325        self.connection.send(PacketHeader::batch(id), req).await?;
326
327        let token_stream = TokenStream::new(&mut self.connection).try_unfold();
328
329        let columns = token_stream
330            .try_fold(None, |mut columns, token| async move {
331                if let ReceivedToken::NewResultset(metadata) = token {
332                    columns = Some(metadata.columns.clone());
333                };
334
335                Ok(columns)
336            })
337            .await?;
338
339        // now start bulk upload
340        let columns: Vec<_> = columns
341            .ok_or_else(|| {
342                crate::Error::Protocol("expecting column metadata from query but not found".into())
343            })?
344            .into_iter()
345            .filter(|column| column.base.flags.contains(ColumnFlag::Updateable))
346            .collect();
347
348        Ok(BulkLoadColumns::new(columns))
349    }
350
351    /// Starts a bulk insert using previously discovered target columns.
352    pub async fn bulk_insert_with_columns<'a>(
353        &'a mut self,
354        table: &str,
355        columns: BulkLoadColumns<'a>,
356    ) -> crate::Result<BulkLoadRequest<'a, S>> {
357        let columns = columns.into_inner();
358
359        self.connection.flush_stream().await?;
360        let col_data = columns.iter().map(|c| format!("{}", c)).join(", ");
361        let query = format!("INSERT BULK {} ({})", table, col_data);
362
363        let req = BatchRequest::new(query, self.connection.context().transaction_descriptor());
364        let id = self.connection.context_mut().next_packet_id();
365
366        self.connection.send(PacketHeader::batch(id), req).await?;
367
368        let ts = TokenStream::new(&mut self.connection);
369        ts.flush_done().await?;
370
371        BulkLoadRequest::new(&mut self.connection, columns)
372    }
373
374    /// Closes this database connection explicitly.
375    pub async fn close(self) -> crate::Result<()> {
376        self.connection.close().await
377    }
378
379    pub(crate) fn rpc_params<'a>(query: impl Into<Cow<'a, str>>) -> Vec<RpcParam<'a>> {
380        vec![
381            RpcParam {
382                name: Cow::Borrowed("stmt"),
383                flags: BitFlags::empty(),
384                value: ColumnData::String(Some(query.into())),
385            },
386            RpcParam {
387                name: Cow::Borrowed("params"),
388                flags: BitFlags::empty(),
389                value: ColumnData::I32(Some(0)),
390            },
391        ]
392    }
393
394    pub(crate) async fn rpc_perform_query<'a, 'b>(
395        &'a mut self,
396        proc_id: RpcProcId,
397        mut rpc_params: Vec<RpcParam<'b>>,
398        params: impl Iterator<Item = ColumnData<'b>>,
399    ) -> crate::Result<()>
400    where
401        'a: 'b,
402    {
403        let mut param_str = String::new();
404
405        for (i, param) in params.enumerate() {
406            if i > 0 {
407                param_str.push(',')
408            }
409            param_str.push_str(&format!("@P{} ", i + 1));
410            param_str.push_str(&param.type_name());
411
412            rpc_params.push(RpcParam {
413                name: Cow::Owned(format!("@P{}", i + 1)),
414                flags: BitFlags::empty(),
415                value: param,
416            });
417        }
418
419        if let Some(params) = rpc_params.iter_mut().find(|x| x.name == "params") {
420            params.value = ColumnData::String(Some(param_str.into()));
421        }
422
423        let req = TokenRpcRequest::new(
424            proc_id,
425            rpc_params,
426            self.connection.context().transaction_descriptor(),
427        );
428
429        let id = self.connection.context_mut().next_packet_id();
430        self.connection.send(PacketHeader::rpc(id), req).await?;
431
432        Ok(())
433    }
434}