tokio_postgres/
binary_copy.rs

//! Utilities for working with the PostgreSQL binary copy format.
2
3use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
4use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
5use byteorder::{BigEndian, ByteOrder};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use futures_util::{SinkExt, Stream};
8use pin_project_lite::pin_project;
9use postgres_types::BorrowToSql;
10use std::io;
11use std::io::Cursor;
12use std::ops::Range;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{ready, Context, Poll};
16
/// Signature bytes that open a binary copy stream.
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
/// Length in bytes of the fixed header: the signature followed by the 32-bit flags field and the
/// 32-bit header extension length.
const HEADER_LEN: usize = MAGIC.len() + 2 * std::mem::size_of::<i32>();
19
20pin_project! {
21    /// A type which serializes rows into the PostgreSQL binary copy format.
22    ///
23    /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
24    pub struct BinaryCopyInWriter {
25        #[pin]
26        sink: CopyInSink<Bytes>,
27        types: Vec<Type>,
28        buf: BytesMut,
29    }
30}
31
32impl BinaryCopyInWriter {
33    /// Creates a new writer which will write rows of the provided types to the provided sink.
34    pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
35        let mut buf = BytesMut::new();
36        buf.put_slice(MAGIC);
37        buf.put_i32(0); // flags
38        buf.put_i32(0); // header extension
39
40        BinaryCopyInWriter {
41            sink,
42            types: types.to_vec(),
43            buf,
44        }
45    }
46
47    /// Writes a single row.
48    ///
49    /// # Panics
50    ///
51    /// Panics if the number of values provided does not match the number expected.
52    pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
53        self.write_raw(slice_iter(values)).await
54    }
55
56    /// A maximally-flexible version of `write`.
57    ///
58    /// # Panics
59    ///
60    /// Panics if the number of values provided does not match the number expected.
61    pub async fn write_raw<P, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
62    where
63        P: BorrowToSql,
64        I: IntoIterator<Item = P>,
65        I::IntoIter: ExactSizeIterator,
66    {
67        let mut this = self.project();
68
69        let values = values.into_iter();
70        assert!(
71            values.len() == this.types.len(),
72            "expected {} values but got {}",
73            this.types.len(),
74            values.len(),
75        );
76
77        this.buf.put_i16(this.types.len() as i16);
78
79        for (i, (value, type_)) in values.zip(this.types).enumerate() {
80            let idx = this.buf.len();
81            this.buf.put_i32(0);
82            let len = match value
83                .borrow_to_sql()
84                .to_sql_checked(type_, this.buf)
85                .map_err(|e| Error::to_sql(e, i))?
86            {
87                IsNull::Yes => -1,
88                IsNull::No => i32::try_from(this.buf.len() - idx - 4)
89                    .map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
90            };
91            BigEndian::write_i32(&mut this.buf[idx..], len);
92        }
93
94        if this.buf.len() > 4096 {
95            this.sink.send(this.buf.split().freeze()).await?;
96        }
97
98        Ok(())
99    }
100
101    /// Completes the copy, returning the number of rows added.
102    ///
103    /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
104    pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
105        let mut this = self.project();
106
107        this.buf.put_i16(-1);
108        this.sink.send(this.buf.split().freeze()).await?;
109        this.sink.finish().await
110    }
111}
112
113struct Header {
114    has_oids: bool,
115}
116
117pin_project! {
118    /// A stream of rows deserialized from the PostgreSQL binary copy format.
119    pub struct BinaryCopyOutStream {
120        #[pin]
121        stream: CopyOutStream,
122        types: Arc<Vec<Type>>,
123        header: Option<Header>,
124    }
125}
126
127impl BinaryCopyOutStream {
128    /// Creates a stream from a raw copy out stream and the types of the columns being returned.
129    pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
130        BinaryCopyOutStream {
131            stream,
132            types: Arc::new(types.to_vec()),
133            header: None,
134        }
135    }
136}
137
138impl Stream for BinaryCopyOutStream {
139    type Item = Result<BinaryCopyOutRow, Error>;
140
141    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
142        let this = self.project();
143
144        let chunk = match ready!(this.stream.poll_next(cx)) {
145            Some(Ok(chunk)) => chunk,
146            Some(Err(e)) => return Poll::Ready(Some(Err(e))),
147            None => return Poll::Ready(Some(Err(Error::closed()))),
148        };
149        let mut chunk = Cursor::new(chunk);
150
151        let has_oids = match &this.header {
152            Some(header) => header.has_oids,
153            None => {
154                check_remaining(&chunk, HEADER_LEN)?;
155                if !chunk.chunk().starts_with(MAGIC) {
156                    return Poll::Ready(Some(Err(Error::parse(io::Error::new(
157                        io::ErrorKind::InvalidData,
158                        "invalid magic value",
159                    )))));
160                }
161                chunk.advance(MAGIC.len());
162
163                let flags = chunk.get_i32();
164                let has_oids = (flags & (1 << 16)) != 0;
165
166                let header_extension = chunk.get_u32() as usize;
167                check_remaining(&chunk, header_extension)?;
168                chunk.advance(header_extension);
169
170                *this.header = Some(Header { has_oids });
171                has_oids
172            }
173        };
174
175        check_remaining(&chunk, 2)?;
176        let mut len = chunk.get_i16();
177        if len == -1 {
178            return Poll::Ready(None);
179        }
180
181        if has_oids {
182            len += 1;
183        }
184        if len as usize != this.types.len() {
185            return Poll::Ready(Some(Err(Error::parse(io::Error::new(
186                io::ErrorKind::InvalidInput,
187                format!("expected {} values but got {}", this.types.len(), len),
188            )))));
189        }
190
191        let mut ranges = vec![];
192        for _ in 0..len {
193            check_remaining(&chunk, 4)?;
194            let len = chunk.get_i32();
195            if len == -1 {
196                ranges.push(None);
197            } else {
198                let len = len as usize;
199                check_remaining(&chunk, len)?;
200                let start = chunk.position() as usize;
201                ranges.push(Some(start..start + len));
202                chunk.advance(len);
203            }
204        }
205
206        Poll::Ready(Some(Ok(BinaryCopyOutRow {
207            buf: chunk.into_inner(),
208            ranges,
209            types: this.types.clone(),
210        })))
211    }
212}
213
214fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
215    if buf.remaining() < len {
216        Err(Error::parse(io::Error::new(
217            io::ErrorKind::UnexpectedEof,
218            "unexpected EOF",
219        )))
220    } else {
221        Ok(())
222    }
223}
224
225/// A row of data parsed from a binary copy out stream.
226pub struct BinaryCopyOutRow {
227    buf: Bytes,
228    ranges: Vec<Option<Range<usize>>>,
229    types: Arc<Vec<Type>>,
230}
231
232impl BinaryCopyOutRow {
233    /// Like `get`, but returns a `Result` rather than panicking.
234    pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
235    where
236        T: FromSql<'a>,
237    {
238        let type_ = match self.types.get(idx) {
239            Some(type_) => type_,
240            None => return Err(Error::column(idx.to_string())),
241        };
242
243        if !T::accepts(type_) {
244            return Err(Error::from_sql(
245                Box::new(WrongType::new::<T>(type_.clone())),
246                idx,
247            ));
248        }
249
250        let r = match &self.ranges[idx] {
251            Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
252            None => T::from_sql_null(type_),
253        };
254
255        r.map_err(|e| Error::from_sql(e, idx))
256    }
257
258    /// Deserializes a value from the row.
259    ///
260    /// # Panics
261    ///
262    /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
263    pub fn get<'a, T>(&'a self, idx: usize) -> T
264    where
265        T: FromSql<'a>,
266    {
267        match self.try_get(idx) {
268            Ok(value) => value,
269            Err(e) => panic!("error retrieving column {}: {}", idx, e),
270        }
271    }
272}