sqlx_build_trust_postgres/copy.rs
1use futures_core::future::BoxFuture;
2use std::borrow::Cow;
3use std::ops::{Deref, DerefMut};
4
5use futures_core::stream::BoxStream;
6use sqlx_core::bytes::{BufMut, Bytes};
7
8use crate::connection::PgConnection;
9use crate::error::{Error, Result};
10use crate::ext::async_stream::TryAsyncStream;
11use crate::io::{AsyncRead, AsyncReadExt};
12use crate::message::{
13 CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
14};
15use crate::pool::{Pool, PoolConnection};
16use crate::Postgres;
17
18impl PgConnection {
19 /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
20 /// to Postgres. This is a more efficient way to import data into Postgres as compared to
21 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
22 ///
23 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
24 /// returned.
25 ///
26 /// Command examples and accepted formats for `COPY` data are shown here:
27 /// https://www.postgresql.org/docs/current/sql-copy.html
28 ///
29 /// ### Note
30 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
31 /// will return an error the next time it is used.
32 pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
33 PgCopyIn::begin(self, statement).await
34 }
35
36 /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
37 /// from Postgres. This is a more efficient way to export data from Postgres but
38 /// arrives in chunks of one of a few data formats (text/CSV/binary).
39 ///
40 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
41 /// an error is returned.
42 ///
43 /// Note that once this process has begun, unless you read the stream to completion,
44 /// it can only be canceled in two ways:
45 ///
46 /// 1. by closing the connection, or:
47 /// 2. by using another connection to kill the server process that is sending the data as shown
48 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
49 ///
50 /// If you don't read the stream to completion, the next time the connection is used it will
51 /// need to read and discard all the remaining queued data, which could take some time.
52 ///
53 /// Command examples and accepted formats for `COPY` data are shown here:
54 /// https://www.postgresql.org/docs/current/sql-copy.html
55 #[allow(clippy::needless_lifetimes)]
56 pub async fn copy_out_raw<'c>(
57 &'c mut self,
58 statement: &str,
59 ) -> Result<BoxStream<'c, Result<Bytes>>> {
60 pg_begin_copy_out(self, statement).await
61 }
62}
63
64/// Implements methods for directly executing `COPY FROM/TO STDOUT` on a [`PgPool`].
65///
66/// This is a replacement for the inherent methods on `PgPool` which could not exist
67/// once the Postgres driver was moved out into its own crate.
68pub trait PgPoolCopyExt {
69 /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
70 /// This is a more efficient way to import data into Postgres as compared to
71 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
72 ///
73 /// A single connection will be checked out for the duration.
74 ///
75 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
76 /// returned.
77 ///
78 /// Command examples and accepted formats for `COPY` data are shown here:
79 /// https://www.postgresql.org/docs/current/sql-copy.html
80 ///
81 /// ### Note
82 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
83 /// will return an error the next time it is used.
84 fn copy_in_raw<'a>(
85 &'a self,
86 statement: &'a str,
87 ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>>;
88
89 /// Issue a `COPY TO STDOUT` statement and begin streaming data
90 /// from Postgres. This is a more efficient way to export data from Postgres but
91 /// arrives in chunks of one of a few data formats (text/CSV/binary).
92 ///
93 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
94 /// an error is returned.
95 ///
96 /// Note that once this process has begun, unless you read the stream to completion,
97 /// it can only be canceled in two ways:
98 ///
99 /// 1. by closing the connection, or:
100 /// 2. by using another connection to kill the server process that is sending the data as shown
101 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
102 ///
103 /// If you don't read the stream to completion, the next time the connection is used it will
104 /// need to read and discard all the remaining queued data, which could take some time.
105 ///
106 /// Command examples and accepted formats for `COPY` data are shown here:
107 /// https://www.postgresql.org/docs/current/sql-copy.html
108 fn copy_out_raw<'a>(
109 &'a self,
110 statement: &'a str,
111 ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>>;
112}
113
114impl PgPoolCopyExt for Pool<Postgres> {
115 fn copy_in_raw<'a>(
116 &'a self,
117 statement: &'a str,
118 ) -> BoxFuture<'a, Result<PgCopyIn<PoolConnection<Postgres>>>> {
119 Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await })
120 }
121
122 fn copy_out_raw<'a>(
123 &'a self,
124 statement: &'a str,
125 ) -> BoxFuture<'a, Result<BoxStream<'static, Result<Bytes>>>> {
126 Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await })
127 }
128}
129
130/// A connection in streaming `COPY FROM STDIN` mode.
131///
132/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
133///
134/// ### Note
135/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
136/// will return an error the next time it is used.
137#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
138pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
139 conn: Option<C>,
140 response: CopyResponse,
141}
142
143impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
144 async fn begin(mut conn: C, statement: &str) -> Result<Self> {
145 conn.wait_until_ready().await?;
146 conn.stream.send(Query(statement)).await?;
147
148 let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
149 Ok(res) => res,
150 Err(e) => {
151 conn.stream.recv().await?;
152 return Err(e);
153 }
154 };
155
156 Ok(PgCopyIn {
157 conn: Some(conn),
158 response,
159 })
160 }
161
162 /// Returns `true` if Postgres is expecting data in text or CSV format.
163 pub fn is_textual(&self) -> bool {
164 self.response.format == 0
165 }
166
167 /// Returns the number of columns expected in the input.
168 pub fn num_columns(&self) -> usize {
169 assert_eq!(
170 self.response.num_columns as usize,
171 self.response.format_codes.len(),
172 "num_columns does not match format_codes.len()"
173 );
174 self.response.format_codes.len()
175 }
176
177 /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
178 ///
179 /// ### Panics
180 /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
181 pub fn column_is_textual(&self, column: usize) -> bool {
182 self.response.format_codes[column] == 0
183 }
184
185 /// Send a chunk of `COPY` data.
186 ///
187 /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
188 pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
189 self.conn
190 .as_deref_mut()
191 .expect("send_data: conn taken")
192 .stream
193 .send(CopyData(data))
194 .await?;
195
196 Ok(self)
197 }
198
199 /// Copy data directly from `source` to the database without requiring an intermediate buffer.
200 ///
201 /// `source` will be read to the end.
202 ///
203 /// ### Note: Completion Step Required
204 /// You must still call either [Self::finish] or [Self::abort] to complete the process.
205 ///
206 /// ### Note: Runtime Features
207 /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
208 /// depending on which runtime feature is used.
209 ///
210 /// The runtime features _used_ to be mutually exclusive, but are no longer.
211 /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
212 /// takes precedent.
213 pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
214 // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
215 struct BufGuard<'s>(&'s mut Vec<u8>);
216
217 impl Drop for BufGuard<'_> {
218 fn drop(&mut self) {
219 self.0.clear()
220 }
221 }
222
223 let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
224 loop {
225 let buf = conn.stream.write_buffer_mut();
226
227 // CopyData format code and reserved space for length
228 buf.put_slice(b"d\0\0\0\x04");
229
230 let read = match () {
231 // Tokio lets us read into the buffer without zeroing first
232 #[cfg(feature = "_rt-tokio")]
233 _ => source.read_buf(buf.buf_mut()).await?,
234 #[cfg(not(feature = "_rt-tokio"))]
235 _ => source.read(buf.init_remaining_mut()).await?,
236 };
237
238 if read == 0 {
239 // This will end up sending an empty `CopyData` packet but that should be fine.
240 break;
241 }
242
243 buf.advance(read);
244
245 // Write the length
246 let read32 = u32::try_from(read)
247 .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
248
249 (&mut buf.get_mut()[1..]).put_u32(read32 + 4);
250
251 conn.stream.flush().await?;
252 }
253
254 Ok(self)
255 }
256
257 /// Signal that the `COPY` process should be aborted and any data received should be discarded.
258 ///
259 /// The given message can be used for indicating the reason for the abort in the database logs.
260 ///
261 /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
262 pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
263 let mut conn = self
264 .conn
265 .take()
266 .expect("PgCopyIn::fail_with: conn taken illegally");
267
268 conn.stream.send(CopyFail::new(msg)).await?;
269
270 match conn.stream.recv().await {
271 Ok(msg) => Err(err_protocol!(
272 "fail_with: expected ErrorResponse, got: {:?}",
273 msg.format
274 )),
275 Err(Error::Database(e)) => {
276 match e.code() {
277 Some(Cow::Borrowed("57014")) => {
278 // postgres abort received error code
279 conn.stream
280 .recv_expect(MessageFormat::ReadyForQuery)
281 .await?;
282 Ok(())
283 }
284 _ => Err(Error::Database(e)),
285 }
286 }
287 Err(e) => Err(e),
288 }
289 }
290
291 /// Signal that the `COPY` process is complete.
292 ///
293 /// The number of rows affected is returned.
294 pub async fn finish(mut self) -> Result<u64> {
295 let mut conn = self
296 .conn
297 .take()
298 .expect("CopyWriter::finish: conn taken illegally");
299
300 conn.stream.send(CopyDone).await?;
301 let cc: CommandComplete = match conn
302 .stream
303 .recv_expect(MessageFormat::CommandComplete)
304 .await
305 {
306 Ok(cc) => cc,
307 Err(e) => {
308 conn.stream.recv().await?;
309 return Err(e);
310 }
311 };
312
313 conn.stream
314 .recv_expect(MessageFormat::ReadyForQuery)
315 .await?;
316
317 Ok(cc.rows_affected())
318 }
319}
320
321impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
322 fn drop(&mut self) {
323 if let Some(mut conn) = self.conn.take() {
324 conn.stream.write(CopyFail::new(
325 "PgCopyIn dropped without calling finish() or fail()",
326 ));
327 }
328 }
329}
330
331async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
332 mut conn: C,
333 statement: &str,
334) -> Result<BoxStream<'c, Result<Bytes>>> {
335 conn.wait_until_ready().await?;
336 conn.stream.send(Query(statement)).await?;
337
338 let _: CopyResponse = conn
339 .stream
340 .recv_expect(MessageFormat::CopyOutResponse)
341 .await?;
342
343 let stream: TryAsyncStream<'c, Bytes> = try_stream! {
344 loop {
345 let msg = conn.stream.recv().await?;
346 match msg.format {
347 MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
348 MessageFormat::CopyDone => {
349 let _ = msg.decode::<CopyDone>()?;
350 conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
351 conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
352 return Ok(())
353 },
354 _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
355 }
356 }
357 };
358
359 Ok(Box::pin(stream))
360}