sqlx_postgres/
copy.rs

1use std::borrow::Cow;
2use std::ops::{Deref, DerefMut};
3
4use futures_core::future::BoxFuture;
5use futures_core::stream::BoxStream;
6
7use sqlx_core::bytes::{BufMut, Bytes};
8
9use crate::connection::PgConnection;
10use crate::error::{Error, Result};
11use crate::ext::async_stream::TryAsyncStream;
12use crate::io::AsyncRead;
13use crate::message::{
14    BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
15    CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
16};
17use crate::pool::{Pool, PoolConnection};
18use crate::Postgres;
19
20impl PgConnection {
21    /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
22    /// to Postgres. This is a more efficient way to import data into Postgres as compared to
23    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
24    ///
25    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
26    /// returned.
27    ///
28    /// Command examples and accepted formats for `COPY` data are shown here:
29    /// <https://www.postgresql.org/docs/current/sql-copy.html>
30    ///
31    /// ### Note
32    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
33    /// will return an error the next time it is used.
34    pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
35        PgCopyIn::begin(self, statement).await
36    }
37
38    /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
39    /// from Postgres. This is a more efficient way to export data from Postgres but
40    /// arrives in chunks of one of a few data formats (text/CSV/binary).
41    ///
42    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
43    /// an error is returned.
44    ///
45    /// Note that once this process has begun, unless you read the stream to completion,
46    /// it can only be canceled in two ways:
47    ///
48    /// 1. by closing the connection, or:
49    /// 2. by using another connection to kill the server process that is sending the data as shown
50    ///    [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
51    ///
52    /// If you don't read the stream to completion, the next time the connection is used it will
53    /// need to read and discard all the remaining queued data, which could take some time.
54    ///
55    /// Command examples and accepted formats for `COPY` data are shown here:
56    /// <https://www.postgresql.org/docs/current/sql-copy.html>
57    #[allow(clippy::needless_lifetimes)]
58    pub async fn copy_out_raw<'c>(
59        &'c mut self,
60        statement: &str,
61    ) -> Result<BoxStream<'c, Result<Bytes>>> {
62        pg_begin_copy_out(self, statement).await
63    }
64}
65
66/// Implements methods for directly executing `COPY FROM/TO STDOUT` on a [`PgPool`][crate::PgPool].
67///
68/// This is a replacement for the inherent methods on `PgPool` which could not exist
69/// once the Postgres driver was moved out into its own crate.
70pub trait PgPoolCopyExt {
71    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
72    /// This is a more efficient way to import data into Postgres as compared to
73    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
74    ///
75    /// A single connection will be checked out for the duration.
76    ///
77    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
78    /// returned.
79    ///
80    /// Command examples and accepted formats for `COPY` data are shown here:
81    /// <https://www.postgresql.org/docs/current/sql-copy.html>
82    ///
83    /// ### Note
84    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
85    /// will return an error the next time it is used.
86    fn copy_in_raw<'a>(
87        &'a self,
88        statement: &'a str,
89    ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>>;
90
91    /// Issue a `COPY TO STDOUT` statement and begin streaming data
92    /// from Postgres. This is a more efficient way to export data from Postgres but
93    /// arrives in chunks of one of a few data formats (text/CSV/binary).
94    ///
95    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
96    /// an error is returned.
97    ///
98    /// Note that once this process has begun, unless you read the stream to completion,
99    /// it can only be canceled in two ways:
100    ///
101    /// 1. by closing the connection, or:
102    /// 2. by using another connection to kill the server process that is sending the data as shown
103    ///    [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
104    ///
105    /// If you don't read the stream to completion, the next time the connection is used it will
106    /// need to read and discard all the remaining queued data, which could take some time.
107    ///
108    /// Command examples and accepted formats for `COPY` data are shown here:
109    /// <https://www.postgresql.org/docs/current/sql-copy.html>
110    fn copy_out_raw<'a>(
111        &'a self,
112        statement: &'a str,
113    ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>>;
114}
115
116impl PgPoolCopyExt for Pool<Postgres> {
117    fn copy_in_raw<'a>(
118        &'a self,
119        statement: &'a str,
120    ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>> {
121        Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await })
122    }
123
124    fn copy_out_raw<'a>(
125        &'a self,
126        statement: &'a str,
127    ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>> {
128        Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await })
129    }
130}
131
/// Largest payload, in bytes, sent in a single `CopyData` message.
///
/// Computed as `(1 GiB - 1) - 1 - 4`, where the final 4 bytes account for the message's
/// length prefix; the result is 1 GiB - 6 bytes.
pub const PG_COPY_MAX_DATA_LEN: usize = (1 << 30) - 1 - 1 - 4;
134
135/// A connection in streaming `COPY FROM STDIN` mode.
136///
137/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
138///
139/// ### Note
140/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
141/// will return an error the next time it is used.
142#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
143pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
144    conn: Option<C>,
145    response: CopyResponseData,
146}
147
148impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
149    async fn begin(mut conn: C, statement: &str) -> Result<Self> {
150        conn.wait_until_ready().await?;
151        conn.inner.stream.send(Query(statement)).await?;
152
153        let response = match conn.inner.stream.recv_expect::<CopyInResponse>().await {
154            Ok(res) => res.0,
155            Err(e) => {
156                conn.inner.stream.recv().await?;
157                return Err(e);
158            }
159        };
160
161        Ok(PgCopyIn {
162            conn: Some(conn),
163            response,
164        })
165    }
166
167    /// Returns `true` if Postgres is expecting data in text or CSV format.
168    pub fn is_textual(&self) -> bool {
169        self.response.format == 0
170    }
171
172    /// Returns the number of columns expected in the input.
173    pub fn num_columns(&self) -> usize {
174        assert_eq!(
175            self.response.num_columns.unsigned_abs() as usize,
176            self.response.format_codes.len(),
177            "num_columns does not match format_codes.len()"
178        );
179        self.response.format_codes.len()
180    }
181
182    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
183    ///
184    /// ### Panics
185    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
186    pub fn column_is_textual(&self, column: usize) -> bool {
187        self.response.format_codes[column] == 0
188    }
189
190    /// Send a chunk of `COPY` data.
191    ///
192    /// The data is sent in chunks if it exceeds the maximum length of a `CopyData` message (1 GiB - 6
193    /// bytes) and may be partially sent if this call is cancelled.
194    ///
195    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
196    pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
197        for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) {
198            self.conn
199                .as_deref_mut()
200                .expect("send_data: conn taken")
201                .inner
202                .stream
203                .send(CopyData(chunk))
204                .await?;
205        }
206
207        Ok(self)
208    }
209
210    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
211    ///
212    /// `source` will be read to the end.
213    ///
214    /// ### Note: Completion Step Required
215    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
216    ///
217    /// ### Note: Runtime Features
218    /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
219    /// depending on which runtime feature is used.
220    ///
221    /// The runtime features _used_ to be mutually exclusive, but are no longer.
222    /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
223    /// takes precedent.
224    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
225        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
226        loop {
227            let buf = conn.inner.stream.write_buffer_mut();
228
229            // Write the CopyData format code and reserve space for the length.
230            // This may end up sending an empty `CopyData` packet if, after this point,
231            // we get canceled or read 0 bytes, but that should be fine.
232            buf.put_slice(b"d\0\0\0\x04");
233
234            let read = buf.read_from(&mut source).await?;
235
236            if read == 0 {
237                break;
238            }
239
240            // Write the length
241            let read32 = i32::try_from(read)
242                .map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?;
243
244            (&mut buf.get_mut()[1..]).put_i32(read32 + 4);
245
246            conn.inner.stream.flush().await?;
247        }
248
249        Ok(self)
250    }
251
252    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
253    ///
254    /// The given message can be used for indicating the reason for the abort in the database logs.
255    ///
256    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
257    pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
258        let mut conn = self
259            .conn
260            .take()
261            .expect("PgCopyIn::fail_with: conn taken illegally");
262
263        conn.inner.stream.send(CopyFail::new(msg)).await?;
264
265        match conn.inner.stream.recv().await {
266            Ok(msg) => Err(err_protocol!(
267                "fail_with: expected ErrorResponse, got: {:?}",
268                msg.format
269            )),
270            Err(Error::Database(e)) => {
271                match e.code() {
272                    Some(Cow::Borrowed("57014")) => {
273                        // postgres abort received error code
274                        conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
275                        Ok(())
276                    }
277                    _ => Err(Error::Database(e)),
278                }
279            }
280            Err(e) => Err(e),
281        }
282    }
283
284    /// Signal that the `COPY` process is complete.
285    ///
286    /// The number of rows affected is returned.
287    pub async fn finish(mut self) -> Result<u64> {
288        let mut conn = self
289            .conn
290            .take()
291            .expect("CopyWriter::finish: conn taken illegally");
292
293        conn.inner.stream.send(CopyDone).await?;
294        let cc: CommandComplete = match conn.inner.stream.recv_expect().await {
295            Ok(cc) => cc,
296            Err(e) => {
297                conn.inner.stream.recv().await?;
298                return Err(e);
299            }
300        };
301
302        conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
303
304        Ok(cc.rows_affected())
305    }
306}
307
308impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
309    fn drop(&mut self) {
310        if let Some(mut conn) = self.conn.take() {
311            conn.inner
312                .stream
313                .write_msg(CopyFail::new(
314                    "PgCopyIn dropped without calling finish() or fail()",
315                ))
316                .expect("BUG: PgCopyIn abort message should not be too large");
317        }
318    }
319}
320
321async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
322    mut conn: C,
323    statement: &str,
324) -> Result<BoxStream<'c, Result<Bytes>>> {
325    conn.wait_until_ready().await?;
326    conn.inner.stream.send(Query(statement)).await?;
327
328    let _: CopyOutResponse = conn.inner.stream.recv_expect().await?;
329
330    let stream: TryAsyncStream<'c, Bytes> = try_stream! {
331        loop {
332            match conn.inner.stream.recv().await {
333                Err(e) => {
334                    conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
335                    return Err(e);
336                },
337                Ok(msg) => match msg.format {
338                    BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
339                    BackendMessageFormat::CopyDone => {
340                        let _ = msg.decode::<CopyDone>()?;
341                        conn.inner.stream.recv_expect::<CommandComplete>().await?;
342                        conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
343                        return Ok(())
344                    },
345                    _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
346                }
347            }
348        }
349    };
350
351    Ok(Box::pin(stream))
352}