sqlx_core_oldapi/postgres/
copy.rs

1use crate::error::{Error, Result};
2use crate::ext::async_stream::TryAsyncStream;
3use crate::pool::{Pool, PoolConnection};
4use crate::postgres::connection::PgConnection;
5use crate::postgres::message::{
6    CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
7};
8use crate::postgres::Postgres;
9use bytes::{BufMut, Bytes};
10use futures_core::stream::BoxStream;
11use smallvec::alloc::borrow::Cow;
12use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt};
13use std::ops::{Deref, DerefMut};
14
15impl PgConnection {
16    /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
17    /// to Postgres. This is a more efficient way to import data into Postgres as compared to
18    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
19    ///
20    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
21    /// returned.
22    ///
23    /// Command examples and accepted formats for `COPY` data are shown here:
24    /// https://www.postgresql.org/docs/current/sql-copy.html
25    ///
26    /// ### Note
27    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
28    /// will return an error the next time it is used.
29    pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
30        PgCopyIn::begin(self, statement).await
31    }
32
33    /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
34    /// from Postgres. This is a more efficient way to export data from Postgres but
35    /// arrives in chunks of one of a few data formats (text/CSV/binary).
36    ///
37    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
38    /// an error is returned.
39    ///
40    /// Note that once this process has begun, unless you read the stream to completion,
41    /// it can only be canceled in two ways:
42    ///
43    /// 1. by closing the connection, or:
44    /// 2. by using another connection to kill the server process that is sending the data as shown
45    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
46    ///
47    /// If you don't read the stream to completion, the next time the connection is used it will
48    /// need to read and discard all the remaining queued data, which could take some time.
49    ///
50    /// Command examples and accepted formats for `COPY` data are shown here:
51    /// https://www.postgresql.org/docs/current/sql-copy.html
52    #[allow(clippy::needless_lifetimes)]
53    pub async fn copy_out_raw<'c>(
54        &'c mut self,
55        statement: &str,
56    ) -> Result<BoxStream<'c, Result<Bytes>>> {
57        pg_begin_copy_out(self, statement).await
58    }
59}
60
61impl Pool<Postgres> {
62    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
63    /// This is a more efficient way to import data into Postgres as compared to
64    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
65    ///
66    /// A single connection will be checked out for the duration.
67    ///
68    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
69    /// returned.
70    ///
71    /// Command examples and accepted formats for `COPY` data are shown here:
72    /// https://www.postgresql.org/docs/current/sql-copy.html
73    ///
74    /// ### Note
75    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
76    /// will return an error the next time it is used.
77    pub async fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>> {
78        PgCopyIn::begin(self.acquire().await?, statement).await
79    }
80
81    /// Issue a `COPY TO STDOUT` statement and begin streaming data
82    /// from Postgres. This is a more efficient way to export data from Postgres but
83    /// arrives in chunks of one of a few data formats (text/CSV/binary).
84    ///
85    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
86    /// an error is returned.
87    ///
88    /// Note that once this process has begun, unless you read the stream to completion,
89    /// it can only be canceled in two ways:
90    ///
91    /// 1. by closing the connection, or:
92    /// 2. by using another connection to kill the server process that is sending the data as shown
93    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
94    ///
95    /// If you don't read the stream to completion, the next time the connection is used it will
96    /// need to read and discard all the remaining queued data, which could take some time.
97    ///
98    /// Command examples and accepted formats for `COPY` data are shown here:
99    /// https://www.postgresql.org/docs/current/sql-copy.html
100    pub async fn copy_out_raw(&self, statement: &str) -> Result<BoxStream<'static, Result<Bytes>>> {
101        pg_begin_copy_out(self.acquire().await?, statement).await
102    }
103}
104
105/// A connection in streaming `COPY FROM STDIN` mode.
106///
107/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
108///
109/// ### Note
110/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
111/// will return an error the next time it is used.
112#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
113pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
114    conn: Option<C>,
115    response: CopyResponse,
116}
117
118impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
119    async fn begin(conn: C, statement: &str) -> Result<Self> {
120        let mut conn = Self::start_copy(conn, statement).await?;
121        match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
122            Ok(response) => Ok(PgCopyIn {
123                conn: Some(conn),
124                response,
125            }),
126            Err(e) => {
127                conn.stream
128                    .send(CopyFail::new("failed to start COPY"))
129                    .await?;
130                conn.stream
131                    .recv_expect::<()>(MessageFormat::ReadyForQuery)
132                    .await?;
133                Err(e)
134            }
135        }
136    }
137
138    async fn start_copy(mut conn: C, statement: &str) -> Result<C> {
139        conn.wait_until_ready().await?;
140        conn.stream.send(Query(statement)).await?;
141        Ok(conn)
142    }
143
144    /// Returns `true` if Postgres is expecting data in text or CSV format.
145    pub fn is_textual(&self) -> bool {
146        self.response.format == 0
147    }
148
149    /// Returns the number of columns expected in the input.
150    pub fn num_columns(&self) -> usize {
151        assert_eq!(
152            self.response.num_columns as usize,
153            self.response.format_codes.len(),
154            "num_columns does not match format_codes.len()"
155        );
156        self.response.format_codes.len()
157    }
158
159    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
160    ///
161    /// ### Panics
162    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
163    pub fn column_is_textual(&self, column: usize) -> bool {
164        self.response.format_codes[column] == 0
165    }
166
167    /// Send a chunk of `COPY` data.
168    ///
169    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
170    pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
171        self.conn
172            .as_deref_mut()
173            .expect("send_data: conn taken")
174            .stream
175            .send(CopyData(data))
176            .await?;
177
178        Ok(self)
179    }
180
181    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
182    ///
183    /// `source` will be read to the end.
184    ///
185    /// ### Note
186    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
187    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
188        // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
189        struct BufGuard<'s>(&'s mut Vec<u8>);
190
191        impl Drop for BufGuard<'_> {
192            fn drop(&mut self) {
193                self.0.clear()
194            }
195        }
196
197        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
198
199        // flush any existing messages in the buffer and clear it
200        conn.stream.flush().await?;
201
202        {
203            let buf_stream = &mut *conn.stream;
204            let stream = &mut buf_stream.stream;
205
206            // ensures the buffer isn't left in an inconsistent state
207            let mut guard = BufGuard(&mut buf_stream.wbuf);
208
209            let buf: &mut Vec<u8> = &mut guard.0;
210            buf.push(b'd'); // CopyData format code
211            buf.resize(5, 0); // reserve space for the length
212
213            loop {
214                let read = match () {
215                    // Tokio lets us read into the buffer without zeroing first
216                    #[cfg(feature = "_rt-tokio")]
217                    _ if buf.len() != buf.capacity() => {
218                        // in case we have some data in the buffer, which can occur
219                        // if the previous write did not fill the buffer
220                        buf.truncate(5);
221                        source.read_buf(buf).await?
222                    }
223                    _ => {
224                        // should be a no-op unless len != capacity
225                        buf.resize(buf.capacity(), 0);
226                        source.read(&mut buf[5..]).await?
227                    }
228                };
229
230                if read == 0 {
231                    break;
232                }
233
234                let read32 = u32::try_from(read)
235                    .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
236
237                (&mut buf[1..]).put_u32(read32 + 4);
238
239                stream.write_all(&buf[..read + 5]).await?;
240                stream.flush().await?;
241            }
242        }
243
244        Ok(self)
245    }
246
247    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
248    ///
249    /// The given message can be used for indicating the reason for the abort in the database logs.
250    ///
251    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
252    pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
253        let mut conn = self
254            .conn
255            .take()
256            .expect("PgCopyIn::fail_with: conn taken illegally");
257
258        conn.stream.send(CopyFail::new(msg)).await?;
259
260        match conn.stream.recv().await {
261            Ok(msg) => Err(err_protocol!(
262                "fail_with: expected ErrorResponse, got: {:?}",
263                msg.format
264            )),
265            Err(Error::Database(e)) => {
266                match e.code() {
267                    Some(Cow::Borrowed("57014")) => {
268                        // postgres abort received error code
269                        conn.stream
270                            .recv_expect::<()>(MessageFormat::ReadyForQuery)
271                            .await?;
272                        Ok(())
273                    }
274                    _ => Err(Error::Database(e)),
275                }
276            }
277            Err(e) => Err(e),
278        }
279    }
280
281    /// Signal that the `COPY` process is complete.
282    ///
283    /// The number of rows affected is returned.
284    pub async fn finish(mut self) -> Result<u64> {
285        let mut conn = self
286            .conn
287            .take()
288            .expect("CopyWriter::finish: conn taken illegally");
289
290        conn.stream.send(CopyDone).await?;
291        let cc: CommandComplete = conn
292            .stream
293            .recv_expect(MessageFormat::CommandComplete)
294            .await?;
295
296        conn.stream
297            .recv_expect::<()>(MessageFormat::ReadyForQuery)
298            .await?;
299
300        Ok(cc.rows_affected())
301    }
302}
303
304impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
305    fn drop(&mut self) {
306        if let Some(mut conn) = self.conn.take() {
307            conn.stream.write(CopyFail::new(
308                "PgCopyIn dropped without calling finish() or fail()",
309            ));
310        }
311    }
312}
313
314async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
315    mut conn: C,
316    statement: &str,
317) -> Result<BoxStream<'c, Result<Bytes>>> {
318    conn.wait_until_ready().await?;
319    conn.stream.send(Query(statement)).await?;
320
321    let _: CopyResponse = conn
322        .stream
323        .recv_expect(MessageFormat::CopyOutResponse)
324        .await?;
325
326    let stream: TryAsyncStream<'c, Bytes> = try_stream! {
327        loop {
328            let msg = conn.stream.recv().await?;
329            match msg.format {
330                MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
331                MessageFormat::CopyDone => {
332                    let _ = msg.decode::<CopyDone>()?;
333                    conn.stream.recv_expect::<CommandComplete>(MessageFormat::CommandComplete).await?;
334                    conn.stream.recv_expect::<()>(MessageFormat::ReadyForQuery).await?;
335                    return Ok(())
336                },
337                _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
338            }
339        }
340    };
341
342    Ok(Box::pin(stream))
343}