tokio_postgres/
client.rs

1use crate::codec::BackendMessages;
2use crate::config::{SslMode, SslNegotiation};
3use crate::connection::{Request, RequestMessages};
4use crate::copy_out::CopyOutStream;
5#[cfg(feature = "runtime")]
6use crate::keepalive::KeepaliveConfig;
7use crate::query::RowStream;
8use crate::simple_query::SimpleQueryStream;
9#[cfg(feature = "runtime")]
10use crate::tls::MakeTlsConnect;
11use crate::tls::TlsConnect;
12use crate::types::{Oid, ToSql, Type};
13#[cfg(feature = "runtime")]
14use crate::Socket;
15use crate::{
16    copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
17    Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
18};
19use bytes::{Buf, BytesMut};
20use fallible_iterator::FallibleIterator;
21use futures_channel::mpsc;
22use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt};
23use parking_lot::Mutex;
24use postgres_protocol::message::backend::Message;
25use postgres_types::BorrowToSql;
26use std::collections::HashMap;
27use std::fmt;
28#[cfg(feature = "runtime")]
29use std::net::IpAddr;
30#[cfg(feature = "runtime")]
31use std::path::PathBuf;
32use std::sync::Arc;
33use std::task::{Context, Poll};
34#[cfg(feature = "runtime")]
35use std::time::Duration;
36use tokio::io::{AsyncRead, AsyncWrite};
37
38pub struct Responses {
39    receiver: mpsc::Receiver<BackendMessages>,
40    cur: BackendMessages,
41}
42
43impl Responses {
44    pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
45        loop {
46            match self.cur.next().map_err(Error::parse)? {
47                Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
48                Some(message) => return Poll::Ready(Ok(message)),
49                None => {}
50            }
51
52            match ready!(self.receiver.poll_next_unpin(cx)) {
53                Some(messages) => self.cur = messages,
54                None => return Poll::Ready(Err(Error::closed())),
55            }
56        }
57    }
58
59    pub async fn next(&mut self) -> Result<Message, Error> {
60        future::poll_fn(|cx| self.poll_next(cx)).await
61    }
62}
63
64/// A cache of type info and prepared statements for fetching type info
65/// (corresponding to the queries in the [prepare](prepare) module).
66#[derive(Default)]
67struct CachedTypeInfo {
68    /// A statement for basic information for a type from its
69    /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
70    /// fallback).
71    typeinfo: Option<Statement>,
72    /// A statement for getting information for a composite type from its OID.
73    /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
74    typeinfo_composite: Option<Statement>,
75    /// A statement for getting information for a composite type from its OID.
76    /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or
77    /// its fallback).
78    typeinfo_enum: Option<Statement>,
79
80    /// Cache of types already looked up.
81    types: HashMap<Oid, Type>,
82}
83
84pub struct InnerClient {
85    sender: mpsc::UnboundedSender<Request>,
86    cached_typeinfo: Mutex<CachedTypeInfo>,
87
88    /// A buffer to use when writing out postgres commands.
89    buffer: Mutex<BytesMut>,
90}
91
92impl InnerClient {
93    pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
94        let (sender, receiver) = mpsc::channel(1);
95        let request = Request { messages, sender };
96        self.sender
97            .unbounded_send(request)
98            .map_err(|_| Error::closed())?;
99
100        Ok(Responses {
101            receiver,
102            cur: BackendMessages::empty(),
103        })
104    }
105
106    pub fn typeinfo(&self) -> Option<Statement> {
107        self.cached_typeinfo.lock().typeinfo.clone()
108    }
109
110    pub fn set_typeinfo(&self, statement: &Statement) {
111        self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
112    }
113
114    pub fn typeinfo_composite(&self) -> Option<Statement> {
115        self.cached_typeinfo.lock().typeinfo_composite.clone()
116    }
117
118    pub fn set_typeinfo_composite(&self, statement: &Statement) {
119        self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
120    }
121
122    pub fn typeinfo_enum(&self) -> Option<Statement> {
123        self.cached_typeinfo.lock().typeinfo_enum.clone()
124    }
125
126    pub fn set_typeinfo_enum(&self, statement: &Statement) {
127        self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
128    }
129
130    pub fn type_(&self, oid: Oid) -> Option<Type> {
131        self.cached_typeinfo.lock().types.get(&oid).cloned()
132    }
133
134    pub fn set_type(&self, oid: Oid, type_: &Type) {
135        self.cached_typeinfo.lock().types.insert(oid, type_.clone());
136    }
137
138    pub fn clear_type_cache(&self) {
139        self.cached_typeinfo.lock().types.clear();
140    }
141
142    /// Call the given function with a buffer to be used when writing out
143    /// postgres commands.
144    pub fn with_buf<F, R>(&self, f: F) -> R
145    where
146        F: FnOnce(&mut BytesMut) -> R,
147    {
148        let mut buffer = self.buffer.lock();
149        let r = f(&mut buffer);
150        buffer.clear();
151        r
152    }
153}
154
/// Socket-level settings for a connection.
///
/// A `Client` stores these and copies them into every `CancelToken` it creates.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub(crate) struct SocketConfig {
    /// Address of the server endpoint.
    pub addr: Addr,
    /// Host name associated with the address, if any.
    pub hostname: Option<String>,
    /// Server port.
    pub port: u16,
    /// Optional limit on how long connecting may take.
    pub connect_timeout: Option<Duration>,
    /// Optional TCP user timeout.
    pub tcp_user_timeout: Option<Duration>,
    /// Optional TCP keepalive settings.
    pub keepalive: Option<KeepaliveConfig>,
}
165
/// The kind of endpoint a connection was made to.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub(crate) enum Addr {
    /// A TCP/IP endpoint identified by its IP address.
    Tcp(IpAddr),
    /// A filesystem-path endpoint, compiled only on Unix targets. This is presumably used for
    /// Unix-domain sockets; confirm whether it names the socket file or its directory.
    #[cfg(unix)]
    Unix(PathBuf),
}
173
174/// An asynchronous PostgreSQL client.
175///
176/// The client is one half of what is returned when a connection is established. Users interact with the database
177/// through this client object.
178pub struct Client {
179    inner: Arc<InnerClient>,
180    #[cfg(feature = "runtime")]
181    socket_config: Option<SocketConfig>,
182    ssl_mode: SslMode,
183    ssl_negotiation: SslNegotiation,
184    process_id: i32,
185    secret_key: i32,
186}
187
188impl Client {
189    pub(crate) fn new(
190        sender: mpsc::UnboundedSender<Request>,
191        ssl_mode: SslMode,
192        ssl_negotiation: SslNegotiation,
193        process_id: i32,
194        secret_key: i32,
195    ) -> Client {
196        Client {
197            inner: Arc::new(InnerClient {
198                sender,
199                cached_typeinfo: Default::default(),
200                buffer: Default::default(),
201            }),
202            #[cfg(feature = "runtime")]
203            socket_config: None,
204            ssl_mode,
205            ssl_negotiation,
206            process_id,
207            secret_key,
208        }
209    }
210
211    pub(crate) fn inner(&self) -> &Arc<InnerClient> {
212        &self.inner
213    }
214
215    #[cfg(feature = "runtime")]
216    pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
217        self.socket_config = Some(socket_config);
218    }
219
220    /// Creates a new prepared statement.
221    ///
222    /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
223    /// which are set when executed. Prepared statements can only be used with the connection that created them.
224    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
225        self.prepare_typed(query, &[]).await
226    }
227
228    /// Like `prepare`, but allows the types of query parameters to be explicitly specified.
229    ///
230    /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
231    /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
232    pub async fn prepare_typed(
233        &self,
234        query: &str,
235        parameter_types: &[Type],
236    ) -> Result<Statement, Error> {
237        prepare::prepare(&self.inner, query, parameter_types).await
238    }
239
240    /// Executes a statement, returning a vector of the resulting rows.
241    ///
242    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
243    /// provided, 1-indexed.
244    ///
245    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
246    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
247    /// with the `prepare` method.
248    pub async fn query<T>(
249        &self,
250        statement: &T,
251        params: &[&(dyn ToSql + Sync)],
252    ) -> Result<Vec<Row>, Error>
253    where
254        T: ?Sized + ToStatement,
255    {
256        self.query_raw(statement, slice_iter(params))
257            .await?
258            .try_collect()
259            .await
260    }
261
262    /// Executes a statement which returns a single row, returning it.
263    ///
264    /// Returns an error if the query does not return exactly one row.
265    ///
266    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
267    /// provided, 1-indexed.
268    ///
269    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
270    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
271    /// with the `prepare` method.
272    pub async fn query_one<T>(
273        &self,
274        statement: &T,
275        params: &[&(dyn ToSql + Sync)],
276    ) -> Result<Row, Error>
277    where
278        T: ?Sized + ToStatement,
279    {
280        self.query_opt(statement, params)
281            .await
282            .and_then(|res| res.ok_or_else(Error::row_count))
283    }
284
285    /// Executes a statements which returns zero or one rows, returning it.
286    ///
287    /// Returns an error if the query returns more than one row.
288    ///
289    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
290    /// provided, 1-indexed.
291    ///
292    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
293    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
294    /// with the `prepare` method.
295    pub async fn query_opt<T>(
296        &self,
297        statement: &T,
298        params: &[&(dyn ToSql + Sync)],
299    ) -> Result<Option<Row>, Error>
300    where
301        T: ?Sized + ToStatement,
302    {
303        let stream = self.query_raw(statement, slice_iter(params)).await?;
304        pin_mut!(stream);
305
306        let mut first = None;
307
308        // Originally this was two calls to `try_next().await?`,
309        // once for the first element, and second to error if more than one.
310        //
311        // However, this new form with only one .await in a loop generates
312        // slightly smaller codegen/stack usage for the resulting future.
313        while let Some(row) = stream.try_next().await? {
314            if first.is_some() {
315                return Err(Error::row_count());
316            }
317
318            first = Some(row);
319        }
320
321        Ok(first)
322    }
323
324    /// The maximally flexible version of [`query`].
325    ///
326    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
327    /// provided, 1-indexed.
328    ///
329    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
330    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
331    /// with the `prepare` method.
332    ///
333    /// [`query`]: #method.query
334    ///
335    /// # Examples
336    ///
337    /// ```no_run
338    /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
339    /// use futures_util::{pin_mut, TryStreamExt};
340    ///
341    /// let params: Vec<String> = vec![
342    ///     "first param".into(),
343    ///     "second param".into(),
344    /// ];
345    /// let mut it = client.query_raw(
346    ///     "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
347    ///     params,
348    /// ).await?;
349    ///
350    /// pin_mut!(it);
351    /// while let Some(row) = it.try_next().await? {
352    ///     let foo: i32 = row.get("foo");
353    ///     println!("foo: {}", foo);
354    /// }
355    /// # Ok(())
356    /// # }
357    /// ```
358    pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
359    where
360        T: ?Sized + ToStatement,
361        P: BorrowToSql,
362        I: IntoIterator<Item = P>,
363        I::IntoIter: ExactSizeIterator,
364    {
365        let statement = statement.__convert().into_statement(self).await?;
366        query::query(&self.inner, statement, params).await
367    }
368
369    /// Like `query`, but requires the types of query parameters to be explicitly specified.
370    ///
371    /// Compared to `query`, this method allows performing queries without three round trips (for
372    /// prepare, execute, and close) by requiring the caller to specify parameter values along with
373    /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
374    /// supported (such as Cloudflare Workers with Hyperdrive).
375    ///
376    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the
377    /// parameter of the list provided, 1-indexed.
378    pub async fn query_typed(
379        &self,
380        query: &str,
381        params: &[(&(dyn ToSql + Sync), Type)],
382    ) -> Result<Vec<Row>, Error> {
383        self.query_typed_raw(query, params.iter().map(|(v, t)| (*v, t.clone())))
384            .await?
385            .try_collect()
386            .await
387    }
388
389    /// The maximally flexible version of [`query_typed`].
390    ///
391    /// Compared to `query`, this method allows performing queries without three round trips (for
392    /// prepare, execute, and close) by requiring the caller to specify parameter values along with
393    /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
394    /// supported (such as Cloudflare Workers with Hyperdrive).
395    ///
396    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the
397    /// parameter of the list provided, 1-indexed.
398    ///
399    /// [`query_typed`]: #method.query_typed
400    ///
401    /// # Examples
402    ///
403    /// ```no_run
404    /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
405    /// use futures_util::{pin_mut, TryStreamExt};
406    /// use tokio_postgres::types::Type;
407    ///
408    /// let params: Vec<(String, Type)> = vec![
409    ///     ("first param".into(), Type::TEXT),
410    ///     ("second param".into(), Type::TEXT),
411    /// ];
412    /// let mut it = client.query_typed_raw(
413    ///     "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
414    ///     params,
415    /// ).await?;
416    ///
417    /// pin_mut!(it);
418    /// while let Some(row) = it.try_next().await? {
419    ///     let foo: i32 = row.get("foo");
420    ///     println!("foo: {}", foo);
421    /// }
422    /// # Ok(())
423    /// # }
424    /// ```
425    pub async fn query_typed_raw<P, I>(&self, query: &str, params: I) -> Result<RowStream, Error>
426    where
427        P: BorrowToSql,
428        I: IntoIterator<Item = (P, Type)>,
429    {
430        query::query_typed(&self.inner, query, params).await
431    }
432
433    /// Executes a statement, returning the number of rows modified.
434    ///
435    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
436    /// provided, 1-indexed.
437    ///
438    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
439    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
440    /// with the `prepare` method.
441    ///
442    /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
443    pub async fn execute<T>(
444        &self,
445        statement: &T,
446        params: &[&(dyn ToSql + Sync)],
447    ) -> Result<u64, Error>
448    where
449        T: ?Sized + ToStatement,
450    {
451        self.execute_raw(statement, slice_iter(params)).await
452    }
453
454    /// The maximally flexible version of [`execute`].
455    ///
456    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
457    /// provided, 1-indexed.
458    ///
459    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
460    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
461    /// with the `prepare` method.
462    ///
463    /// [`execute`]: #method.execute
464    pub async fn execute_raw<T, P, I>(&self, statement: &T, params: I) -> Result<u64, Error>
465    where
466        T: ?Sized + ToStatement,
467        P: BorrowToSql,
468        I: IntoIterator<Item = P>,
469        I::IntoIter: ExactSizeIterator,
470    {
471        let statement = statement.__convert().into_statement(self).await?;
472        query::execute(self.inner(), statement, params).await
473    }
474
475    /// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data.
476    ///
477    /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. The copy *must*
478    /// be explicitly completed via the `Sink::close` or `finish` methods. If it is not, the copy will be aborted.
479    pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
480    where
481        T: ?Sized + ToStatement,
482        U: Buf + 'static + Send,
483    {
484        let statement = statement.__convert().into_statement(self).await?;
485        copy_in::copy_in(self.inner(), statement).await
486    }
487
488    /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.
489    ///
490    /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any.
491    pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, Error>
492    where
493        T: ?Sized + ToStatement,
494    {
495        let statement = statement.__convert().into_statement(self).await?;
496        copy_out::copy_out(self.inner(), statement).await
497    }
498
499    /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
500    ///
501    /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
502    /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
503    /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
504    /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
505    /// or a row of data. This preserves the framing between the separate statements in the request.
506    ///
507    /// # Warning
508    ///
509    /// Prepared statements should be use for any query which contains user-specified data, as they provided the
510    /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
511    /// them to this method!
512    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
513        self.simple_query_raw(query).await?.try_collect().await
514    }
515
516    pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
517        simple_query::simple_query(self.inner(), query).await
518    }
519
520    /// Executes a sequence of SQL statements using the simple query protocol.
521    ///
522    /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
523    /// point. This is intended for use when, for example, initializing a database schema.
524    ///
525    /// # Warning
526    ///
527    /// Prepared statements should be use for any query which contains user-specified data, as they provided the
528    /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
529    /// them to this method!
530    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
531        simple_query::batch_execute(self.inner(), query).await
532    }
533
534    /// Begins a new database transaction.
535    ///
536    /// The transaction will roll back by default - use the `commit` method to commit it.
537    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
538        self.build_transaction().start().await
539    }
540
541    /// Returns a builder for a transaction with custom settings.
542    ///
543    /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
544    /// attributes.
545    pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
546        TransactionBuilder::new(self)
547    }
548
549    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
550    /// connection associated with this client.
551    pub fn cancel_token(&self) -> CancelToken {
552        CancelToken {
553            #[cfg(feature = "runtime")]
554            socket_config: self.socket_config.clone(),
555            ssl_mode: self.ssl_mode,
556            ssl_negotiation: self.ssl_negotiation,
557            process_id: self.process_id,
558            secret_key: self.secret_key,
559        }
560    }
561
562    /// Attempts to cancel an in-progress query.
563    ///
564    /// The server provides no information about whether a cancellation attempt was successful or not. An error will
565    /// only be returned if the client was unable to connect to the database.
566    ///
567    /// Requires the `runtime` Cargo feature (enabled by default).
568    #[cfg(feature = "runtime")]
569    #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")]
570    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
571    where
572        T: MakeTlsConnect<Socket>,
573    {
574        self.cancel_token().cancel_query(tls).await
575    }
576
577    /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
578    /// connection itself.
579    #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")]
580    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
581    where
582        S: AsyncRead + AsyncWrite + Unpin,
583        T: TlsConnect<S>,
584    {
585        self.cancel_token().cancel_query_raw(stream, tls).await
586    }
587
588    /// Clears the client's type information cache.
589    ///
590    /// When user-defined types are used in a query, the client loads their definitions from the database and caches
591    /// them for the lifetime of the client. If those definitions are changed in the database, this method can be used
592    /// to flush the local cache and allow the new, updated definitions to be loaded.
593    pub fn clear_type_cache(&self) {
594        self.inner().clear_type_cache();
595    }
596
597    /// Determines if the connection to the server has already closed.
598    ///
599    /// In that case, all future queries will fail.
600    pub fn is_closed(&self) -> bool {
601        self.inner.sender.is_closed()
602    }
603
604    #[doc(hidden)]
605    pub fn __private_api_close(&mut self) {
606        self.inner.sender.close_channel()
607    }
608}
609
610impl fmt::Debug for Client {
611    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
612        f.debug_struct("Client").finish()
613    }
614}