sqlx_core_oldapi/postgres/copy.rs
1use crate::error::{Error, Result};
2use crate::ext::async_stream::TryAsyncStream;
3use crate::pool::{Pool, PoolConnection};
4use crate::postgres::connection::PgConnection;
5use crate::postgres::message::{
6 CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
7};
8use crate::postgres::Postgres;
9use bytes::{BufMut, Bytes};
10use futures_core::stream::BoxStream;
11use smallvec::alloc::borrow::Cow;
12use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt};
13use std::ops::{Deref, DerefMut};
14
15impl PgConnection {
16 /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
17 /// to Postgres. This is a more efficient way to import data into Postgres as compared to
18 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
19 ///
20 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
21 /// returned.
22 ///
23 /// Command examples and accepted formats for `COPY` data are shown here:
24 /// https://www.postgresql.org/docs/current/sql-copy.html
25 ///
26 /// ### Note
27 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
28 /// will return an error the next time it is used.
29 pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
30 PgCopyIn::begin(self, statement).await
31 }
32
33 /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
34 /// from Postgres. This is a more efficient way to export data from Postgres but
35 /// arrives in chunks of one of a few data formats (text/CSV/binary).
36 ///
37 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
38 /// an error is returned.
39 ///
40 /// Note that once this process has begun, unless you read the stream to completion,
41 /// it can only be canceled in two ways:
42 ///
43 /// 1. by closing the connection, or:
44 /// 2. by using another connection to kill the server process that is sending the data as shown
45 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
46 ///
47 /// If you don't read the stream to completion, the next time the connection is used it will
48 /// need to read and discard all the remaining queued data, which could take some time.
49 ///
50 /// Command examples and accepted formats for `COPY` data are shown here:
51 /// https://www.postgresql.org/docs/current/sql-copy.html
52 #[allow(clippy::needless_lifetimes)]
53 pub async fn copy_out_raw<'c>(
54 &'c mut self,
55 statement: &str,
56 ) -> Result<BoxStream<'c, Result<Bytes>>> {
57 pg_begin_copy_out(self, statement).await
58 }
59}
60
61impl Pool<Postgres> {
62 /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
63 /// This is a more efficient way to import data into Postgres as compared to
64 /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
65 ///
66 /// A single connection will be checked out for the duration.
67 ///
68 /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
69 /// returned.
70 ///
71 /// Command examples and accepted formats for `COPY` data are shown here:
72 /// https://www.postgresql.org/docs/current/sql-copy.html
73 ///
74 /// ### Note
75 /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
76 /// will return an error the next time it is used.
77 pub async fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>> {
78 PgCopyIn::begin(self.acquire().await?, statement).await
79 }
80
81 /// Issue a `COPY TO STDOUT` statement and begin streaming data
82 /// from Postgres. This is a more efficient way to export data from Postgres but
83 /// arrives in chunks of one of a few data formats (text/CSV/binary).
84 ///
85 /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
86 /// an error is returned.
87 ///
88 /// Note that once this process has begun, unless you read the stream to completion,
89 /// it can only be canceled in two ways:
90 ///
91 /// 1. by closing the connection, or:
92 /// 2. by using another connection to kill the server process that is sending the data as shown
93 /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
94 ///
95 /// If you don't read the stream to completion, the next time the connection is used it will
96 /// need to read and discard all the remaining queued data, which could take some time.
97 ///
98 /// Command examples and accepted formats for `COPY` data are shown here:
99 /// https://www.postgresql.org/docs/current/sql-copy.html
100 pub async fn copy_out_raw(&self, statement: &str) -> Result<BoxStream<'static, Result<Bytes>>> {
101 pg_begin_copy_out(self.acquire().await?, statement).await
102 }
103}
104
105/// A connection in streaming `COPY FROM STDIN` mode.
106///
107/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
108///
109/// ### Note
110/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
111/// will return an error the next time it is used.
112#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
113pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
114 conn: Option<C>,
115 response: CopyResponse,
116}
117
118impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
119 async fn begin(conn: C, statement: &str) -> Result<Self> {
120 let mut conn = Self::start_copy(conn, statement).await?;
121 match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
122 Ok(response) => Ok(PgCopyIn {
123 conn: Some(conn),
124 response,
125 }),
126 Err(e) => {
127 conn.stream
128 .send(CopyFail::new("failed to start COPY"))
129 .await?;
130 conn.stream
131 .recv_expect::<()>(MessageFormat::ReadyForQuery)
132 .await?;
133 Err(e)
134 }
135 }
136 }
137
138 async fn start_copy(mut conn: C, statement: &str) -> Result<C> {
139 conn.wait_until_ready().await?;
140 conn.stream.send(Query(statement)).await?;
141 Ok(conn)
142 }
143
144 /// Returns `true` if Postgres is expecting data in text or CSV format.
145 pub fn is_textual(&self) -> bool {
146 self.response.format == 0
147 }
148
149 /// Returns the number of columns expected in the input.
150 pub fn num_columns(&self) -> usize {
151 #[allow(clippy::cast_sign_loss)]
152 let num_columns = self.response.num_columns as usize;
153 assert_eq!(
154 num_columns,
155 self.response.format_codes.len(),
156 "num_columns does not match format_codes.len()"
157 );
158 self.response.format_codes.len()
159 }
160
161 /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
162 ///
163 /// ### Panics
164 /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
165 pub fn column_is_textual(&self, column: usize) -> bool {
166 self.response.format_codes[column] == 0
167 }
168
169 /// Send a chunk of `COPY` data.
170 ///
171 /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
172 pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
173 self.conn
174 .as_deref_mut()
175 .expect("send_data: conn taken")
176 .stream
177 .send(CopyData(data))
178 .await?;
179
180 Ok(self)
181 }
182
183 /// Copy data directly from `source` to the database without requiring an intermediate buffer.
184 ///
185 /// `source` will be read to the end.
186 ///
187 /// ### Note
188 /// You must still call either [Self::finish] or [Self::abort] to complete the process.
189 pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
190 // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
191 struct BufGuard<'s>(&'s mut Vec<u8>);
192
193 impl Drop for BufGuard<'_> {
194 fn drop(&mut self) {
195 self.0.clear()
196 }
197 }
198
199 let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
200
201 // flush any existing messages in the buffer and clear it
202 conn.stream.flush().await?;
203
204 {
205 let buf_stream = &mut *conn.stream;
206 let stream = &mut buf_stream.stream;
207
208 // ensures the buffer isn't left in an inconsistent state
209 let guard = BufGuard(&mut buf_stream.wbuf);
210
211 let buf: &mut Vec<u8> = guard.0;
212 buf.push(b'd'); // CopyData format code
213 buf.resize(5, 0); // reserve space for the length
214
215 loop {
216 let read = match () {
217 // Tokio lets us read into the buffer without zeroing first
218 #[cfg(feature = "_rt-tokio")]
219 _ if buf.len() != buf.capacity() => {
220 // in case we have some data in the buffer, which can occur
221 // if the previous write did not fill the buffer
222 buf.truncate(5);
223 source.read_buf(buf).await?
224 }
225 _ => {
226 // should be a no-op unless len != capacity
227 buf.resize(buf.capacity(), 0);
228 source.read(&mut buf[5..]).await?
229 }
230 };
231
232 if read == 0 {
233 break;
234 }
235
236 let read32 = u32::try_from(read)
237 .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
238
239 (&mut buf[1..]).put_u32(read32 + 4);
240
241 stream.write_all(&buf[..read + 5]).await?;
242 stream.flush().await?;
243 }
244 }
245
246 Ok(self)
247 }
248
249 /// Signal that the `COPY` process should be aborted and any data received should be discarded.
250 ///
251 /// The given message can be used for indicating the reason for the abort in the database logs.
252 ///
253 /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
254 pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
255 let mut conn = self
256 .conn
257 .take()
258 .expect("PgCopyIn::fail_with: conn taken illegally");
259
260 conn.stream.send(CopyFail::new(msg)).await?;
261
262 match conn.stream.recv().await {
263 Ok(msg) => Err(err_protocol!(
264 "fail_with: expected ErrorResponse, got: {:?}",
265 msg.format
266 )),
267 Err(Error::Database(e)) => {
268 match e.code() {
269 Some(Cow::Borrowed("57014")) => {
270 // postgres abort received error code
271 conn.stream
272 .recv_expect::<()>(MessageFormat::ReadyForQuery)
273 .await?;
274 Ok(())
275 }
276 _ => Err(Error::Database(e)),
277 }
278 }
279 Err(e) => Err(e),
280 }
281 }
282
283 /// Signal that the `COPY` process is complete.
284 ///
285 /// The number of rows affected is returned.
286 pub async fn finish(mut self) -> Result<u64> {
287 let mut conn = self
288 .conn
289 .take()
290 .expect("CopyWriter::finish: conn taken illegally");
291
292 conn.stream.send(CopyDone).await?;
293 let cc: CommandComplete = conn
294 .stream
295 .recv_expect(MessageFormat::CommandComplete)
296 .await?;
297
298 conn.stream
299 .recv_expect::<()>(MessageFormat::ReadyForQuery)
300 .await?;
301
302 Ok(cc.rows_affected())
303 }
304}
305
306impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
307 fn drop(&mut self) {
308 if let Some(mut conn) = self.conn.take() {
309 conn.stream.write(CopyFail::new(
310 "PgCopyIn dropped without calling finish() or fail()",
311 ));
312 }
313 }
314}
315
316async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
317 mut conn: C,
318 statement: &str,
319) -> Result<BoxStream<'c, Result<Bytes>>> {
320 conn.wait_until_ready().await?;
321 conn.stream.send(Query(statement)).await?;
322
323 let _: CopyResponse = conn
324 .stream
325 .recv_expect(MessageFormat::CopyOutResponse)
326 .await?;
327
328 let stream: TryAsyncStream<'c, Bytes> = try_stream! {
329 loop {
330 let msg = conn.stream.recv().await?;
331 match msg.format {
332 MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
333 MessageFormat::CopyDone => {
334 let _ = msg.decode::<CopyDone>()?;
335 conn.stream.recv_expect::<CommandComplete>(MessageFormat::CommandComplete).await?;
336 conn.stream.recv_expect::<()>(MessageFormat::ReadyForQuery).await?;
337 return Ok(())
338 },
339 _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
340 }
341 }
342 };
343
344 Ok(Box::pin(stream))
345}