sqlx_core_guts/postgres/
copy.rs

1use crate::error::{Error, Result};
2use crate::ext::async_stream::TryAsyncStream;
3use crate::pool::{Pool, PoolConnection};
4use crate::postgres::connection::PgConnection;
5use crate::postgres::message::{
6    CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
7};
8use crate::postgres::Postgres;
9use bytes::{BufMut, Bytes};
10use futures_core::stream::BoxStream;
11use smallvec::alloc::borrow::Cow;
12use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt};
13use std::ops::{Deref, DerefMut};
14
15impl PgConnection {
16    /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
17    /// to Postgres. This is a more efficient way to import data into Postgres as compared to
18    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
19    ///
20    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
21    /// returned.
22    ///
23    /// Command examples and accepted formats for `COPY` data are shown here:
24    /// https://www.postgresql.org/docs/current/sql-copy.html
25    ///
26    /// ### Note
27    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
28    /// will return an error the next time it is used.
29    pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
30        PgCopyIn::begin(self, statement).await
31    }
32
33    /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
34    /// from Postgres. This is a more efficient way to export data from Postgres but
35    /// arrives in chunks of one of a few data formats (text/CSV/binary).
36    ///
37    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
38    /// an error is returned.
39    ///
40    /// Note that once this process has begun, unless you read the stream to completion,
41    /// it can only be canceled in two ways:
42    ///
43    /// 1. by closing the connection, or:
44    /// 2. by using another connection to kill the server process that is sending the data as shown
45    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
46    ///
47    /// If you don't read the stream to completion, the next time the connection is used it will
48    /// need to read and discard all the remaining queued data, which could take some time.
49    ///
50    /// Command examples and accepted formats for `COPY` data are shown here:
51    /// https://www.postgresql.org/docs/current/sql-copy.html
52    #[allow(clippy::needless_lifetimes)]
53    pub async fn copy_out_raw<'c>(
54        &'c mut self,
55        statement: &str,
56    ) -> Result<BoxStream<'c, Result<Bytes>>> {
57        pg_begin_copy_out(self, statement).await
58    }
59}
60
61impl Pool<Postgres> {
62    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
63    /// This is a more efficient way to import data into Postgres as compared to
64    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
65    ///
66    /// A single connection will be checked out for the duration.
67    ///
68    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
69    /// returned.
70    ///
71    /// Command examples and accepted formats for `COPY` data are shown here:
72    /// https://www.postgresql.org/docs/current/sql-copy.html
73    ///
74    /// ### Note
75    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
76    /// will return an error the next time it is used.
77    pub async fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>> {
78        PgCopyIn::begin(self.acquire().await?, statement).await
79    }
80
81    /// Issue a `COPY TO STDOUT` statement and begin streaming data
82    /// from Postgres. This is a more efficient way to export data from Postgres but
83    /// arrives in chunks of one of a few data formats (text/CSV/binary).
84    ///
85    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
86    /// an error is returned.
87    ///
88    /// Note that once this process has begun, unless you read the stream to completion,
89    /// it can only be canceled in two ways:
90    ///
91    /// 1. by closing the connection, or:
92    /// 2. by using another connection to kill the server process that is sending the data as shown
93    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
94    ///
95    /// If you don't read the stream to completion, the next time the connection is used it will
96    /// need to read and discard all the remaining queued data, which could take some time.
97    ///
98    /// Command examples and accepted formats for `COPY` data are shown here:
99    /// https://www.postgresql.org/docs/current/sql-copy.html
100    pub async fn copy_out_raw(&self, statement: &str) -> Result<BoxStream<'static, Result<Bytes>>> {
101        pg_begin_copy_out(self.acquire().await?, statement).await
102    }
103}
104
/// A connection in streaming `COPY FROM STDIN` mode.
///
/// Created by [PgConnection::copy_in_raw] or [Pool::copy_in_raw].
///
/// The type parameter `C` is whatever handle gives mutable access to the underlying
/// [PgConnection] (for example `&mut PgConnection` or a `PoolConnection<Postgres>`).
///
/// ### Note
/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
/// will return an error the next time it is used.
#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
    // `Some` while the copy is in progress; taken by `finish()`/`abort()` so that the
    // `Drop` impl can tell whether the copy was terminated explicitly.
    conn: Option<C>,
    // The `CopyInResponse` message the server sent when the copy began; it carries the
    // overall format and the per-column format codes.
    response: CopyResponse,
}
117
118impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
119    async fn begin(mut conn: C, statement: &str) -> Result<Self> {
120        conn.wait_until_ready().await?;
121        conn.stream.send(Query(statement)).await?;
122
123        let response: CopyResponse = conn
124            .stream
125            .recv_expect(MessageFormat::CopyInResponse)
126            .await?;
127
128        Ok(PgCopyIn {
129            conn: Some(conn),
130            response,
131        })
132    }
133
134    /// Returns `true` if Postgres is expecting data in text or CSV format.
135    pub fn is_textual(&self) -> bool {
136        self.response.format == 0
137    }
138
139    /// Returns the number of columns expected in the input.
140    pub fn num_columns(&self) -> usize {
141        assert_eq!(
142            self.response.num_columns as usize,
143            self.response.format_codes.len(),
144            "num_columns does not match format_codes.len()"
145        );
146        self.response.format_codes.len()
147    }
148
149    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
150    ///
151    /// ### Panics
152    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
153    pub fn column_is_textual(&self, column: usize) -> bool {
154        self.response.format_codes[column] == 0
155    }
156
157    /// Send a chunk of `COPY` data.
158    ///
159    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
160    pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
161        self.conn
162            .as_deref_mut()
163            .expect("send_data: conn taken")
164            .stream
165            .send(CopyData(data))
166            .await?;
167
168        Ok(self)
169    }
170
171    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
172    ///
173    /// `source` will be read to the end.
174    ///
175    /// ### Note
176    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
177    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
178        // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
179        struct BufGuard<'s>(&'s mut Vec<u8>);
180
181        impl Drop for BufGuard<'_> {
182            fn drop(&mut self) {
183                self.0.clear()
184            }
185        }
186
187        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
188
189        // flush any existing messages in the buffer and clear it
190        conn.stream.flush().await?;
191
192        {
193            let buf_stream = &mut *conn.stream;
194            let stream = &mut buf_stream.stream;
195
196            // ensures the buffer isn't left in an inconsistent state
197            let mut guard = BufGuard(&mut buf_stream.wbuf);
198
199            let buf: &mut Vec<u8> = &mut guard.0;
200            buf.push(b'd'); // CopyData format code
201            buf.resize(5, 0); // reserve space for the length
202
203            loop {
204                let read = match () {
205                    // Tokio lets us read into the buffer without zeroing first
206                    #[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))]
207                    _ if buf.len() != buf.capacity() => {
208                        // in case we have some data in the buffer, which can occur
209                        // if the previous write did not fill the buffer
210                        buf.truncate(5);
211                        source.read_buf(buf).await?
212                    }
213                    _ => {
214                        // should be a no-op unless len != capacity
215                        buf.resize(buf.capacity(), 0);
216                        source.read(&mut buf[5..]).await?
217                    }
218                };
219
220                if read == 0 {
221                    break;
222                }
223
224                let read32 = u32::try_from(read)
225                    .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
226
227                (&mut buf[1..]).put_u32(read32 + 4);
228
229                stream.write_all(&buf[..read + 5]).await?;
230                stream.flush().await?;
231            }
232        }
233
234        Ok(self)
235    }
236
237    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
238    ///
239    /// The given message can be used for indicating the reason for the abort in the database logs.
240    ///
241    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
242    pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
243        let mut conn = self
244            .conn
245            .take()
246            .expect("PgCopyIn::fail_with: conn taken illegally");
247
248        conn.stream.send(CopyFail::new(msg)).await?;
249
250        match conn.stream.recv().await {
251            Ok(msg) => Err(err_protocol!(
252                "fail_with: expected ErrorResponse, got: {:?}",
253                msg.format
254            )),
255            Err(Error::Database(e)) => {
256                match e.code() {
257                    Some(Cow::Borrowed("57014")) => {
258                        // postgres abort received error code
259                        conn.stream
260                            .recv_expect(MessageFormat::ReadyForQuery)
261                            .await?;
262                        Ok(())
263                    }
264                    _ => Err(Error::Database(e)),
265                }
266            }
267            Err(e) => Err(e),
268        }
269    }
270
271    /// Signal that the `COPY` process is complete.
272    ///
273    /// The number of rows affected is returned.
274    pub async fn finish(mut self) -> Result<u64> {
275        let mut conn = self
276            .conn
277            .take()
278            .expect("CopyWriter::finish: conn taken illegally");
279
280        conn.stream.send(CopyDone).await?;
281        let cc: CommandComplete = conn
282            .stream
283            .recv_expect(MessageFormat::CommandComplete)
284            .await?;
285
286        conn.stream
287            .recv_expect(MessageFormat::ReadyForQuery)
288            .await?;
289
290        Ok(cc.rows_affected())
291    }
292}
293
294impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
295    fn drop(&mut self) {
296        if let Some(mut conn) = self.conn.take() {
297            conn.stream.write(CopyFail::new(
298                "PgCopyIn dropped without calling finish() or fail()",
299            ));
300        }
301    }
302}
303
304async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
305    mut conn: C,
306    statement: &str,
307) -> Result<BoxStream<'c, Result<Bytes>>> {
308    conn.wait_until_ready().await?;
309    conn.stream.send(Query(statement)).await?;
310
311    let _: CopyResponse = conn
312        .stream
313        .recv_expect(MessageFormat::CopyOutResponse)
314        .await?;
315
316    let stream: TryAsyncStream<'c, Bytes> = try_stream! {
317        loop {
318            let msg = conn.stream.recv().await?;
319            match msg.format {
320                MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
321                MessageFormat::CopyDone => {
322                    let _ = msg.decode::<CopyDone>()?;
323                    conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
324                    conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
325                    return Ok(())
326                },
327                _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
328            }
329        }
330    };
331
332    Ok(Box::pin(stream))
333}