sqlx_postgres/copy.rs
1use std::borrow::Cow;
2use std::ops::{Deref, DerefMut};
3
4use futures_core::future::BoxFuture;
5use futures_core::stream::BoxStream;
6
7use sqlx_core::bytes::{BufMut, Bytes};
8
9use crate::connection::PgConnection;
10use crate::error::{Error, Result};
11use crate::ext::async_stream::TryAsyncStream;
12use crate::io::AsyncRead;
13use crate::message::{
14 BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
15 CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
16};
17use crate::pool::{Pool, PoolConnection};
18use crate::Postgres;
19
20impl PgConnection {
21 /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
22 /// to Postgres. This is a more efficient way to import data into Postgres as compared to
23 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
24 ///
25 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
26 /// returned.
27 ///
28 /// Command examples and accepted formats for `COPY` data are shown here:
29 /// <https://www.postgresql.org/docs/current/sql-copy.html>
30 ///
31 /// ### Note
32 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
33 /// will return an error the next time it is used.
34 pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
35 PgCopyIn::begin(self, statement).await
36 }
37
38 /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
39 /// from Postgres. This is a more efficient way to export data from Postgres but
40 /// arrives in chunks of one of a few data formats (text/CSV/binary).
41 ///
42 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
43 /// an error is returned.
44 ///
45 /// Note that once this process has begun, unless you read the stream to completion,
46 /// it can only be canceled in two ways:
47 ///
48 /// 1. by closing the connection, or:
49 /// 2. by using another connection to kill the server process that is sending the data as shown
50 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
51 ///
52 /// If you don't read the stream to completion, the next time the connection is used it will
53 /// need to read and discard all the remaining queued data, which could take some time.
54 ///
55 /// Command examples and accepted formats for `COPY` data are shown here:
56 /// <https://www.postgresql.org/docs/current/sql-copy.html>
57 #[allow(clippy::needless_lifetimes)]
58 pub async fn copy_out_raw<'c>(
59 &'c mut self,
60 statement: &str,
61 ) -> Result<BoxStream<'c, Result<Bytes>>> {
62 pg_begin_copy_out(self, statement).await
63 }
64}
65
/// Implements methods for directly executing `COPY ... FROM STDIN` and `COPY ... TO STDOUT`
/// on a [`PgPool`][crate::PgPool].
///
/// This is a replacement for the inherent methods on `PgPool` which could not exist
/// once the Postgres driver was moved out into its own crate.
pub trait PgPoolCopyExt {
    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
    /// This is a more efficient way to import data into Postgres as compared to
    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
    ///
    /// A single connection will be checked out for the duration.
    ///
    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
    /// returned.
    ///
    /// Command examples and accepted formats for `COPY` data are shown here:
    /// <https://www.postgresql.org/docs/current/sql-copy.html>
    ///
    /// ### Note
    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
    /// will return an error the next time it is used.
    fn copy_in_raw<'a>(
        &'a self,
        statement: &'a str,
    ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>>;

    /// Issue a `COPY TO STDOUT` statement and begin streaming data
    /// from Postgres. This is a more efficient way to export data from Postgres but
    /// arrives in chunks of one of a few data formats (text/CSV/binary).
    ///
    /// The returned stream is `'static`: it does not borrow from `self` or `statement`.
    ///
    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
    /// an error is returned.
    ///
    /// Note that once this process has begun, unless you read the stream to completion,
    /// it can only be canceled in two ways:
    ///
    /// 1. by closing the connection, or:
    /// 2. by using another connection to kill the server process that is sending the data as shown
    ///    [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
    ///
    /// If you don't read the stream to completion, the next time the connection is used it will
    /// need to read and discard all the remaining queued data, which could take some time.
    ///
    /// Command examples and accepted formats for `COPY` data are shown here:
    /// <https://www.postgresql.org/docs/current/sql-copy.html>
    fn copy_out_raw<'a>(
        &'a self,
        statement: &'a str,
    ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>>;
}
115
116impl PgPoolCopyExt for Pool<Postgres> {
117 fn copy_in_raw<'a>(
118 &'a self,
119 statement: &'a str,
120 ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>> {
121 Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await })
122 }
123
124 fn copy_out_raw<'a>(
125 &'a self,
126 statement: &'a str,
127 ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>> {
128 Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await })
129 }
130}
131
/// Largest payload, in bytes, that [`PgCopyIn::send`] places in a single `CopyData` message.
// (1 GiB - 1) - 1 - length prefix (4 bytes)
pub const PG_COPY_MAX_DATA_LEN: usize = ((1 << 30) - 1) - 1 - 4;
134
135/// A connection in streaming `COPY FROM STDIN` mode.
136///
137/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
138///
139/// ### Note
140/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
141/// will return an error the next time it is used.
142#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
143pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
144 conn: Option<C>,
145 response: CopyResponseData,
146}
147
148impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
149 async fn begin(mut conn: C, statement: &str) -> Result<Self> {
150 conn.wait_until_ready().await?;
151 conn.inner.stream.send(Query(statement)).await?;
152
153 let response = match conn.inner.stream.recv_expect::<CopyInResponse>().await {
154 Ok(res) => res.0,
155 Err(e) => {
156 conn.inner.stream.recv().await?;
157 return Err(e);
158 }
159 };
160
161 Ok(PgCopyIn {
162 conn: Some(conn),
163 response,
164 })
165 }
166
167 /// Returns `true` if Postgres is expecting data in text or CSV format.
168 pub fn is_textual(&self) -> bool {
169 self.response.format == 0
170 }
171
172 /// Returns the number of columns expected in the input.
173 pub fn num_columns(&self) -> usize {
174 assert_eq!(
175 self.response.num_columns.unsigned_abs() as usize,
176 self.response.format_codes.len(),
177 "num_columns does not match format_codes.len()"
178 );
179 self.response.format_codes.len()
180 }
181
182 /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
183 ///
184 /// ### Panics
185 /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
186 pub fn column_is_textual(&self, column: usize) -> bool {
187 self.response.format_codes[column] == 0
188 }
189
190 /// Send a chunk of `COPY` data.
191 ///
192 /// The data is sent in chunks if it exceeds the maximum length of a `CopyData` message (1 GiB - 6
193 /// bytes) and may be partially sent if this call is cancelled.
194 ///
195 /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
196 pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
197 for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) {
198 self.conn
199 .as_deref_mut()
200 .expect("send_data: conn taken")
201 .inner
202 .stream
203 .send(CopyData(chunk))
204 .await?;
205 }
206
207 Ok(self)
208 }
209
210 /// Copy data directly from `source` to the database without requiring an intermediate buffer.
211 ///
212 /// `source` will be read to the end.
213 ///
214 /// ### Note: Completion Step Required
215 /// You must still call either [Self::finish] or [Self::abort] to complete the process.
216 ///
217 /// ### Note: Runtime Features
218 /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
219 /// depending on which runtime feature is used.
220 ///
221 /// The runtime features _used_ to be mutually exclusive, but are no longer.
222 /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
223 /// takes precedent.
224 pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
225 let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
226 loop {
227 let buf = conn.inner.stream.write_buffer_mut();
228
229 // Write the CopyData format code and reserve space for the length.
230 // This may end up sending an empty `CopyData` packet if, after this point,
231 // we get canceled or read 0 bytes, but that should be fine.
232 buf.put_slice(b"d\0\0\0\x04");
233
234 let read = buf.read_from(&mut source).await?;
235
236 if read == 0 {
237 break;
238 }
239
240 // Write the length
241 let read32 = i32::try_from(read)
242 .map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?;
243
244 (&mut buf.get_mut()[1..]).put_i32(read32 + 4);
245
246 conn.inner.stream.flush().await?;
247 }
248
249 Ok(self)
250 }
251
252 /// Signal that the `COPY` process should be aborted and any data received should be discarded.
253 ///
254 /// The given message can be used for indicating the reason for the abort in the database logs.
255 ///
256 /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
257 pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
258 let mut conn = self
259 .conn
260 .take()
261 .expect("PgCopyIn::fail_with: conn taken illegally");
262
263 conn.inner.stream.send(CopyFail::new(msg)).await?;
264
265 match conn.inner.stream.recv().await {
266 Ok(msg) => Err(err_protocol!(
267 "fail_with: expected ErrorResponse, got: {:?}",
268 msg.format
269 )),
270 Err(Error::Database(e)) => {
271 match e.code() {
272 Some(Cow::Borrowed("57014")) => {
273 // postgres abort received error code
274 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
275 Ok(())
276 }
277 _ => Err(Error::Database(e)),
278 }
279 }
280 Err(e) => Err(e),
281 }
282 }
283
284 /// Signal that the `COPY` process is complete.
285 ///
286 /// The number of rows affected is returned.
287 pub async fn finish(mut self) -> Result<u64> {
288 let mut conn = self
289 .conn
290 .take()
291 .expect("CopyWriter::finish: conn taken illegally");
292
293 conn.inner.stream.send(CopyDone).await?;
294 let cc: CommandComplete = match conn.inner.stream.recv_expect().await {
295 Ok(cc) => cc,
296 Err(e) => {
297 conn.inner.stream.recv().await?;
298 return Err(e);
299 }
300 };
301
302 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
303
304 Ok(cc.rows_affected())
305 }
306}
307
308impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
309 fn drop(&mut self) {
310 if let Some(mut conn) = self.conn.take() {
311 conn.inner
312 .stream
313 .write_msg(CopyFail::new(
314 "PgCopyIn dropped without calling finish() or fail()",
315 ))
316 .expect("BUG: PgCopyIn abort message should not be too large");
317 }
318 }
319}
320
321async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
322 mut conn: C,
323 statement: &str,
324) -> Result<BoxStream<'c, Result<Bytes>>> {
325 conn.wait_until_ready().await?;
326 conn.inner.stream.send(Query(statement)).await?;
327
328 let _: CopyOutResponse = conn.inner.stream.recv_expect().await?;
329
330 let stream: TryAsyncStream<'c, Bytes> = try_stream! {
331 loop {
332 match conn.inner.stream.recv().await {
333 Err(e) => {
334 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
335 return Err(e);
336 },
337 Ok(msg) => match msg.format {
338 BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
339 BackendMessageFormat::CopyDone => {
340 let _ = msg.decode::<CopyDone>()?;
341 conn.inner.stream.recv_expect::<CommandComplete>().await?;
342 conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
343 return Ok(())
344 },
345 _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
346 }
347 }
348 }
349 };
350
351 Ok(Box::pin(stream))
352}