sqlx_build_trust_postgres/
copy.rs

1use futures_core::future::BoxFuture;
2use std::borrow::Cow;
3use std::ops::{Deref, DerefMut};
4
5use futures_core::stream::BoxStream;
6use sqlx_core::bytes::{BufMut, Bytes};
7
8use crate::connection::PgConnection;
9use crate::error::{Error, Result};
10use crate::ext::async_stream::TryAsyncStream;
11use crate::io::{AsyncRead, AsyncReadExt};
12use crate::message::{
13    CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
14};
15use crate::pool::{Pool, PoolConnection};
16use crate::Postgres;
17
18impl PgConnection {
19    /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
20    /// to Postgres. This is a more efficient way to import data into Postgres as compared to
21    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
22    ///
23    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
24    /// returned.
25    ///
26    /// Command examples and accepted formats for `COPY` data are shown here:
27    /// https://www.postgresql.org/docs/current/sql-copy.html
28    ///
29    /// ### Note
30    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
31    /// will return an error the next time it is used.
32    pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
33        PgCopyIn::begin(self, statement).await
34    }
35
36    /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
37    /// from Postgres. This is a more efficient way to export data from Postgres but
38    /// arrives in chunks of one of a few data formats (text/CSV/binary).
39    ///
40    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
41    /// an error is returned.
42    ///
43    /// Note that once this process has begun, unless you read the stream to completion,
44    /// it can only be canceled in two ways:
45    ///
46    /// 1. by closing the connection, or:
47    /// 2. by using another connection to kill the server process that is sending the data as shown
48    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
49    ///
50    /// If you don't read the stream to completion, the next time the connection is used it will
51    /// need to read and discard all the remaining queued data, which could take some time.
52    ///
53    /// Command examples and accepted formats for `COPY` data are shown here:
54    /// https://www.postgresql.org/docs/current/sql-copy.html
55    #[allow(clippy::needless_lifetimes)]
56    pub async fn copy_out_raw<'c>(
57        &'c mut self,
58        statement: &str,
59    ) -> Result<BoxStream<'c, Result<Bytes>>> {
60        pg_begin_copy_out(self, statement).await
61    }
62}
63
64/// Implements methods for directly executing `COPY FROM/TO STDOUT` on a [`PgPool`].
65///
66/// This is a replacement for the inherent methods on `PgPool` which could not exist
67/// once the Postgres driver was moved out into its own crate.
68pub trait PgPoolCopyExt {
69    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
70    /// This is a more efficient way to import data into Postgres as compared to
71    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
72    ///
73    /// A single connection will be checked out for the duration.
74    ///
75    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
76    /// returned.
77    ///
78    /// Command examples and accepted formats for `COPY` data are shown here:
79    /// https://www.postgresql.org/docs/current/sql-copy.html
80    ///
81    /// ### Note
82    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
83    /// will return an error the next time it is used.
84    fn copy_in_raw<'a>(
85        &'a self,
86        statement: &'a str,
87    ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>>;
88
89    /// Issue a `COPY TO STDOUT` statement and begin streaming data
90    /// from Postgres. This is a more efficient way to export data from Postgres but
91    /// arrives in chunks of one of a few data formats (text/CSV/binary).
92    ///
93    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
94    /// an error is returned.
95    ///
96    /// Note that once this process has begun, unless you read the stream to completion,
97    /// it can only be canceled in two ways:
98    ///
99    /// 1. by closing the connection, or:
100    /// 2. by using another connection to kill the server process that is sending the data as shown
101    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
102    ///
103    /// If you don't read the stream to completion, the next time the connection is used it will
104    /// need to read and discard all the remaining queued data, which could take some time.
105    ///
106    /// Command examples and accepted formats for `COPY` data are shown here:
107    /// https://www.postgresql.org/docs/current/sql-copy.html
108    fn copy_out_raw<'a>(
109        &'a self,
110        statement: &'a str,
111    ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>>;
112}
113
114impl PgPoolCopyExt for Pool<Postgres> {
115    fn copy_in_raw<'a>(
116        &'a self,
117        statement: &'a str,
118    ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>> {
119        Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await })
120    }
121
122    fn copy_out_raw<'a>(
123        &'a self,
124        statement: &'a str,
125    ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>> {
126        Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await })
127    }
128}
129
130/// A connection in streaming `COPY FROM STDIN` mode.
131///
132/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
133///
134/// ### Note
135/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
136/// will return an error the next time it is used.
137#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
138pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
139    conn: Option<C>,
140    response: CopyResponse,
141}
142
143impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
144    async fn begin(mut conn: C, statement: &str) -> Result<Self> {
145        conn.wait_until_ready().await?;
146        conn.stream.send(Query(statement)).await?;
147
148        let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
149            Ok(res) => res,
150            Err(e) => {
151                conn.stream.recv().await?;
152                return Err(e);
153            }
154        };
155
156        Ok(PgCopyIn {
157            conn: Some(conn),
158            response,
159        })
160    }
161
162    /// Returns `true` if Postgres is expecting data in text or CSV format.
163    pub fn is_textual(&self) -> bool {
164        self.response.format == 0
165    }
166
167    /// Returns the number of columns expected in the input.
168    pub fn num_columns(&self) -> usize {
169        assert_eq!(
170            self.response.num_columns as usize,
171            self.response.format_codes.len(),
172            "num_columns does not match format_codes.len()"
173        );
174        self.response.format_codes.len()
175    }
176
177    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
178    ///
179    /// ### Panics
180    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
181    pub fn column_is_textual(&self, column: usize) -> bool {
182        self.response.format_codes[column] == 0
183    }
184
185    /// Send a chunk of `COPY` data.
186    ///
187    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
188    pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
189        self.conn
190            .as_deref_mut()
191            .expect("send_data: conn taken")
192            .stream
193            .send(CopyData(data))
194            .await?;
195
196        Ok(self)
197    }
198
199    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
200    ///
201    /// `source` will be read to the end.
202    ///
203    /// ### Note: Completion Step Required
204    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
205    ///
206    /// ### Note: Runtime Features
207    /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
208    /// depending on which runtime feature is used.
209    ///
210    /// The runtime features _used_ to be mutually exclusive, but are no longer.
211    /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
212    /// takes precedent.
213    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
214        // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
215        struct BufGuard<'s>(&'s mut Vec<u8>);
216
217        impl Drop for BufGuard<'_> {
218            fn drop(&mut self) {
219                self.0.clear()
220            }
221        }
222
223        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
224        loop {
225            let buf = conn.stream.write_buffer_mut();
226
227            // CopyData format code and reserved space for length
228            buf.put_slice(b"d\0\0\0\x04");
229
230            let read = match () {
231                // Tokio lets us read into the buffer without zeroing first
232                #[cfg(feature = "_rt-tokio")]
233                _ => source.read_buf(buf.buf_mut()).await?,
234                #[cfg(not(feature = "_rt-tokio"))]
235                _ => source.read(buf.init_remaining_mut()).await?,
236            };
237
238            if read == 0 {
239                // This will end up sending an empty `CopyData` packet but that should be fine.
240                break;
241            }
242
243            buf.advance(read);
244
245            // Write the length
246            let read32 = u32::try_from(read)
247                .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
248
249            (&mut buf.get_mut()[1..]).put_u32(read32 + 4);
250
251            conn.stream.flush().await?;
252        }
253
254        Ok(self)
255    }
256
257    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
258    ///
259    /// The given message can be used for indicating the reason for the abort in the database logs.
260    ///
261    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
262    pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
263        let mut conn = self
264            .conn
265            .take()
266            .expect("PgCopyIn::fail_with: conn taken illegally");
267
268        conn.stream.send(CopyFail::new(msg)).await?;
269
270        match conn.stream.recv().await {
271            Ok(msg) => Err(err_protocol!(
272                "fail_with: expected ErrorResponse, got: {:?}",
273                msg.format
274            )),
275            Err(Error::Database(e)) => {
276                match e.code() {
277                    Some(Cow::Borrowed("57014")) => {
278                        // postgres abort received error code
279                        conn.stream
280                            .recv_expect(MessageFormat::ReadyForQuery)
281                            .await?;
282                        Ok(())
283                    }
284                    _ => Err(Error::Database(e)),
285                }
286            }
287            Err(e) => Err(e),
288        }
289    }
290
291    /// Signal that the `COPY` process is complete.
292    ///
293    /// The number of rows affected is returned.
294    pub async fn finish(mut self) -> Result<u64> {
295        let mut conn = self
296            .conn
297            .take()
298            .expect("CopyWriter::finish: conn taken illegally");
299
300        conn.stream.send(CopyDone).await?;
301        let cc: CommandComplete = match conn
302            .stream
303            .recv_expect(MessageFormat::CommandComplete)
304            .await
305        {
306            Ok(cc) => cc,
307            Err(e) => {
308                conn.stream.recv().await?;
309                return Err(e);
310            }
311        };
312
313        conn.stream
314            .recv_expect(MessageFormat::ReadyForQuery)
315            .await?;
316
317        Ok(cc.rows_affected())
318    }
319}
320
321impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
322    fn drop(&mut self) {
323        if let Some(mut conn) = self.conn.take() {
324            conn.stream.write(CopyFail::new(
325                "PgCopyIn dropped without calling finish() or fail()",
326            ));
327        }
328    }
329}
330
331async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
332    mut conn: C,
333    statement: &str,
334) -> Result<BoxStream<'c, Result<Bytes>>> {
335    conn.wait_until_ready().await?;
336    conn.stream.send(Query(statement)).await?;
337
338    let _: CopyResponse = conn
339        .stream
340        .recv_expect(MessageFormat::CopyOutResponse)
341        .await?;
342
343    let stream: TryAsyncStream<'c, Bytes> = try_stream! {
344        loop {
345            let msg = conn.stream.recv().await?;
346            match msg.format {
347                MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
348                MessageFormat::CopyDone => {
349                    let _ = msg.decode::<CopyDone>()?;
350                    conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
351                    conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
352                    return Ok(())
353                },
354                _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
355            }
356        }
357    };
358
359    Ok(Box::pin(stream))
360}