sqlx_core_oldapi/postgres/
copy.rs

1use crate::error::{Error, Result};
2use crate::ext::async_stream::TryAsyncStream;
3use crate::pool::{Pool, PoolConnection};
4use crate::postgres::connection::PgConnection;
5use crate::postgres::message::{
6    CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
7};
8use crate::postgres::Postgres;
9use bytes::{BufMut, Bytes};
10use futures_core::stream::BoxStream;
11use smallvec::alloc::borrow::Cow;
12use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt};
13use std::ops::{Deref, DerefMut};
14
15impl PgConnection {
16    /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
17    /// to Postgres. This is a more efficient way to import data into Postgres as compared to
18    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
19    ///
20    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
21    /// returned.
22    ///
23    /// Command examples and accepted formats for `COPY` data are shown here:
24    /// https://www.postgresql.org/docs/current/sql-copy.html
25    ///
26    /// ### Note
27    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
28    /// will return an error the next time it is used.
29    pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
30        PgCopyIn::begin(self, statement).await
31    }
32
33    /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
34    /// from Postgres. This is a more efficient way to export data from Postgres but
35    /// arrives in chunks of one of a few data formats (text/CSV/binary).
36    ///
37    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
38    /// an error is returned.
39    ///
40    /// Note that once this process has begun, unless you read the stream to completion,
41    /// it can only be canceled in two ways:
42    ///
43    /// 1. by closing the connection, or:
44    /// 2. by using another connection to kill the server process that is sending the data as shown
45    ///    [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
46    ///
47    /// If you don't read the stream to completion, the next time the connection is used it will
48    /// need to read and discard all the remaining queued data, which could take some time.
49    ///
50    /// Command examples and accepted formats for `COPY` data are shown here:
51    /// https://www.postgresql.org/docs/current/sql-copy.html
52    #[allow(clippy::needless_lifetimes)]
53    pub async fn copy_out_raw<'c>(
54        &'c mut self,
55        statement: &str,
56    ) -> Result<BoxStream<'c, Result<Bytes>>> {
57        pg_begin_copy_out(self, statement).await
58    }
59}
60
61impl Pool<Postgres> {
62    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
63    /// This is a more efficient way to import data into Postgres as compared to
64    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
65    ///
66    /// A single connection will be checked out for the duration.
67    ///
68    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
69    /// returned.
70    ///
71    /// Command examples and accepted formats for `COPY` data are shown here:
72    /// https://www.postgresql.org/docs/current/sql-copy.html
73    ///
74    /// ### Note
75    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
76    /// will return an error the next time it is used.
77    pub async fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>> {
78        PgCopyIn::begin(self.acquire().await?, statement).await
79    }
80
81    /// Issue a `COPY TO STDOUT` statement and begin streaming data
82    /// from Postgres. This is a more efficient way to export data from Postgres but
83    /// arrives in chunks of one of a few data formats (text/CSV/binary).
84    ///
85    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
86    /// an error is returned.
87    ///
88    /// Note that once this process has begun, unless you read the stream to completion,
89    /// it can only be canceled in two ways:
90    ///
91    /// 1. by closing the connection, or:
92    /// 2. by using another connection to kill the server process that is sending the data as shown
93    ///    [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
94    ///
95    /// If you don't read the stream to completion, the next time the connection is used it will
96    /// need to read and discard all the remaining queued data, which could take some time.
97    ///
98    /// Command examples and accepted formats for `COPY` data are shown here:
99    /// https://www.postgresql.org/docs/current/sql-copy.html
100    pub async fn copy_out_raw(&self, statement: &str) -> Result<BoxStream<'static, Result<Bytes>>> {
101        pg_begin_copy_out(self.acquire().await?, statement).await
102    }
103}
104
/// A connection in streaming `COPY FROM STDIN` mode.
///
/// Created by [PgConnection::copy_in_raw] or [Pool::copy_in_raw].
///
/// ### Note
/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
/// will return an error the next time it is used.
#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
    // The connection being copied into; `None` once `finish`, `abort` or `drop` has taken it.
    conn: Option<C>,
    // The `CopyInResponse` the server sent when the `COPY` was started.
    response: CopyResponse,
}
117
118impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
119    async fn begin(conn: C, statement: &str) -> Result<Self> {
120        let mut conn = Self::start_copy(conn, statement).await?;
121        match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
122            Ok(response) => Ok(PgCopyIn {
123                conn: Some(conn),
124                response,
125            }),
126            Err(e) => {
127                conn.stream
128                    .send(CopyFail::new("failed to start COPY"))
129                    .await?;
130                conn.stream
131                    .recv_expect::<()>(MessageFormat::ReadyForQuery)
132                    .await?;
133                Err(e)
134            }
135        }
136    }
137
138    async fn start_copy(mut conn: C, statement: &str) -> Result<C> {
139        conn.wait_until_ready().await?;
140        conn.stream.send(Query(statement)).await?;
141        Ok(conn)
142    }
143
144    /// Returns `true` if Postgres is expecting data in text or CSV format.
145    pub fn is_textual(&self) -> bool {
146        self.response.format == 0
147    }
148
149    /// Returns the number of columns expected in the input.
150    pub fn num_columns(&self) -> usize {
151        #[allow(clippy::cast_sign_loss)]
152        let num_columns = self.response.num_columns as usize;
153        assert_eq!(
154            num_columns,
155            self.response.format_codes.len(),
156            "num_columns does not match format_codes.len()"
157        );
158        self.response.format_codes.len()
159    }
160
161    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
162    ///
163    /// ### Panics
164    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
165    pub fn column_is_textual(&self, column: usize) -> bool {
166        self.response.format_codes[column] == 0
167    }
168
169    /// Send a chunk of `COPY` data.
170    ///
171    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
172    pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
173        self.conn
174            .as_deref_mut()
175            .expect("send_data: conn taken")
176            .stream
177            .send(CopyData(data))
178            .await?;
179
180        Ok(self)
181    }
182
183    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
184    ///
185    /// `source` will be read to the end.
186    ///
187    /// ### Note
188    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
189    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
190        // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
191        struct BufGuard<'s>(&'s mut Vec<u8>);
192
193        impl Drop for BufGuard<'_> {
194            fn drop(&mut self) {
195                self.0.clear()
196            }
197        }
198
199        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
200
201        // flush any existing messages in the buffer and clear it
202        conn.stream.flush().await?;
203
204        {
205            let buf_stream = &mut *conn.stream;
206            let stream = &mut buf_stream.stream;
207
208            // ensures the buffer isn't left in an inconsistent state
209            let guard = BufGuard(&mut buf_stream.wbuf);
210
211            let buf: &mut Vec<u8> = guard.0;
212            buf.push(b'd'); // CopyData format code
213            buf.resize(5, 0); // reserve space for the length
214
215            loop {
216                let read = match () {
217                    // Tokio lets us read into the buffer without zeroing first
218                    #[cfg(feature = "_rt-tokio")]
219                    _ if buf.len() != buf.capacity() => {
220                        // in case we have some data in the buffer, which can occur
221                        // if the previous write did not fill the buffer
222                        buf.truncate(5);
223                        source.read_buf(buf).await?
224                    }
225                    _ => {
226                        // should be a no-op unless len != capacity
227                        buf.resize(buf.capacity(), 0);
228                        source.read(&mut buf[5..]).await?
229                    }
230                };
231
232                if read == 0 {
233                    break;
234                }
235
236                let read32 = u32::try_from(read)
237                    .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
238
239                (&mut buf[1..]).put_u32(read32 + 4);
240
241                stream.write_all(&buf[..read + 5]).await?;
242                stream.flush().await?;
243            }
244        }
245
246        Ok(self)
247    }
248
249    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
250    ///
251    /// The given message can be used for indicating the reason for the abort in the database logs.
252    ///
253    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
254    pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
255        let mut conn = self
256            .conn
257            .take()
258            .expect("PgCopyIn::fail_with: conn taken illegally");
259
260        conn.stream.send(CopyFail::new(msg)).await?;
261
262        match conn.stream.recv().await {
263            Ok(msg) => Err(err_protocol!(
264                "fail_with: expected ErrorResponse, got: {:?}",
265                msg.format
266            )),
267            Err(Error::Database(e)) => {
268                match e.code() {
269                    Some(Cow::Borrowed("57014")) => {
270                        // postgres abort received error code
271                        conn.stream
272                            .recv_expect::<()>(MessageFormat::ReadyForQuery)
273                            .await?;
274                        Ok(())
275                    }
276                    _ => Err(Error::Database(e)),
277                }
278            }
279            Err(e) => Err(e),
280        }
281    }
282
283    /// Signal that the `COPY` process is complete.
284    ///
285    /// The number of rows affected is returned.
286    pub async fn finish(mut self) -> Result<u64> {
287        let mut conn = self
288            .conn
289            .take()
290            .expect("CopyWriter::finish: conn taken illegally");
291
292        conn.stream.send(CopyDone).await?;
293        let cc: CommandComplete = conn
294            .stream
295            .recv_expect(MessageFormat::CommandComplete)
296            .await?;
297
298        conn.stream
299            .recv_expect::<()>(MessageFormat::ReadyForQuery)
300            .await?;
301
302        Ok(cc.rows_affected())
303    }
304}
305
306impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
307    fn drop(&mut self) {
308        if let Some(mut conn) = self.conn.take() {
309            conn.stream.write(CopyFail::new(
310                "PgCopyIn dropped without calling finish() or fail()",
311            ));
312        }
313    }
314}
315
316async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
317    mut conn: C,
318    statement: &str,
319) -> Result<BoxStream<'c, Result<Bytes>>> {
320    conn.wait_until_ready().await?;
321    conn.stream.send(Query(statement)).await?;
322
323    let _: CopyResponse = conn
324        .stream
325        .recv_expect(MessageFormat::CopyOutResponse)
326        .await?;
327
328    let stream: TryAsyncStream<'c, Bytes> = try_stream! {
329        loop {
330            let msg = conn.stream.recv().await?;
331            match msg.format {
332                MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
333                MessageFormat::CopyDone => {
334                    let _ = msg.decode::<CopyDone>()?;
335                    conn.stream.recv_expect::<CommandComplete>(MessageFormat::CommandComplete).await?;
336                    conn.stream.recv_expect::<()>(MessageFormat::ReadyForQuery).await?;
337                    return Ok(())
338                },
339                _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
340            }
341        }
342    };
343
344    Ok(Box::pin(stream))
345}