1use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
4use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
5use byteorder::{BigEndian, ByteOrder};
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use futures_util::{SinkExt, Stream};
8use pin_project_lite::pin_project;
9use postgres_types::BorrowToSql;
10use std::io;
11use std::io::Cursor;
12use std::ops::Range;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{ready, Context, Poll};
16
/// Signature bytes that open every binary `COPY` stream.
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
/// Size of the fixed part of the binary `COPY` header: the signature followed by a
/// 32-bit flags word and a 32-bit header-extension length.
const HEADER_LEN: usize = MAGIC.len() + std::mem::size_of::<i32>() + std::mem::size_of::<u32>();
19
pin_project! {
    /// Serializes rows into the binary `COPY` format and forwards them to a [`CopyInSink`].
    ///
    /// Encoded rows are accumulated in an internal buffer and sent to the sink in chunks;
    /// call `finish` to append the trailer, flush what remains, and complete the copy.
    pub struct BinaryCopyInWriter {
        // Projected as `Pin<&mut _>` (via `#[pin]`) so it can be driven by `send`/`finish`.
        #[pin]
        sink: CopyInSink<Bytes>,
        // Column types, in order; every written row must supply exactly one value per type.
        types: Vec<Type>,
        // Encoded bytes (header and rows) not yet handed to the sink.
        buf: BytesMut,
    }
}
31
32impl BinaryCopyInWriter {
33 pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
35 let mut buf = BytesMut::new();
36 buf.put_slice(MAGIC);
37 buf.put_i32(0); buf.put_i32(0); BinaryCopyInWriter {
41 sink,
42 types: types.to_vec(),
43 buf,
44 }
45 }
46
47 pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
53 self.write_raw(slice_iter(values)).await
54 }
55
56 pub async fn write_raw<P, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
62 where
63 P: BorrowToSql,
64 I: IntoIterator<Item = P>,
65 I::IntoIter: ExactSizeIterator,
66 {
67 let mut this = self.project();
68
69 let values = values.into_iter();
70 assert!(
71 values.len() == this.types.len(),
72 "expected {} values but got {}",
73 this.types.len(),
74 values.len(),
75 );
76
77 this.buf.put_i16(this.types.len() as i16);
78
79 for (i, (value, type_)) in values.zip(this.types).enumerate() {
80 let idx = this.buf.len();
81 this.buf.put_i32(0);
82 let len = match value
83 .borrow_to_sql()
84 .to_sql_checked(type_, this.buf)
85 .map_err(|e| Error::to_sql(e, i))?
86 {
87 IsNull::Yes => -1,
88 IsNull::No => i32::try_from(this.buf.len() - idx - 4)
89 .map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
90 };
91 BigEndian::write_i32(&mut this.buf[idx..], len);
92 }
93
94 if this.buf.len() > 4096 {
95 this.sink.send(this.buf.split().freeze()).await?;
96 }
97
98 Ok(())
99 }
100
101 pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
105 let mut this = self.project();
106
107 this.buf.put_i16(-1);
108 this.sink.send(this.buf.split().freeze()).await?;
109 this.sink.finish().await
110 }
111}
112
/// State parsed from the binary `COPY` header and retained for decoding subsequent rows.
struct Header {
    // Set when bit 16 of the header's flags word is set; each row is then expected to
    // carry one more value than its on-wire field count.
    has_oids: bool,
}
116
pin_project! {
    /// A stream of rows decoded from the binary `COPY` format.
    pub struct BinaryCopyOutStream {
        // Source of raw `COPY` chunks; projected as `Pin<&mut _>` (via `#[pin]`) for polling.
        #[pin]
        stream: CopyOutStream,
        // Expected column types, shared (via `Arc`) with every row that is yielded.
        types: Arc<Vec<Type>>,
        // Parsed header; `None` until the first chunk has been read.
        header: Option<Header>,
    }
}
126
127impl BinaryCopyOutStream {
128 pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
130 BinaryCopyOutStream {
131 stream,
132 types: Arc::new(types.to_vec()),
133 header: None,
134 }
135 }
136}
137
138impl Stream for BinaryCopyOutStream {
139 type Item = Result<BinaryCopyOutRow, Error>;
140
141 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
142 let this = self.project();
143
144 let chunk = match ready!(this.stream.poll_next(cx)) {
145 Some(Ok(chunk)) => chunk,
146 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
147 None => return Poll::Ready(Some(Err(Error::closed()))),
148 };
149 let mut chunk = Cursor::new(chunk);
150
151 let has_oids = match &this.header {
152 Some(header) => header.has_oids,
153 None => {
154 check_remaining(&chunk, HEADER_LEN)?;
155 if !chunk.chunk().starts_with(MAGIC) {
156 return Poll::Ready(Some(Err(Error::parse(io::Error::new(
157 io::ErrorKind::InvalidData,
158 "invalid magic value",
159 )))));
160 }
161 chunk.advance(MAGIC.len());
162
163 let flags = chunk.get_i32();
164 let has_oids = (flags & (1 << 16)) != 0;
165
166 let header_extension = chunk.get_u32() as usize;
167 check_remaining(&chunk, header_extension)?;
168 chunk.advance(header_extension);
169
170 *this.header = Some(Header { has_oids });
171 has_oids
172 }
173 };
174
175 check_remaining(&chunk, 2)?;
176 let mut len = chunk.get_i16();
177 if len == -1 {
178 return Poll::Ready(None);
179 }
180
181 if has_oids {
182 len += 1;
183 }
184 if len as usize != this.types.len() {
185 return Poll::Ready(Some(Err(Error::parse(io::Error::new(
186 io::ErrorKind::InvalidInput,
187 format!("expected {} values but got {}", this.types.len(), len),
188 )))));
189 }
190
191 let mut ranges = vec![];
192 for _ in 0..len {
193 check_remaining(&chunk, 4)?;
194 let len = chunk.get_i32();
195 if len == -1 {
196 ranges.push(None);
197 } else {
198 let len = len as usize;
199 check_remaining(&chunk, len)?;
200 let start = chunk.position() as usize;
201 ranges.push(Some(start..start + len));
202 chunk.advance(len);
203 }
204 }
205
206 Poll::Ready(Some(Ok(BinaryCopyOutRow {
207 buf: chunk.into_inner(),
208 ranges,
209 types: this.types.clone(),
210 })))
211 }
212}
213
214fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
215 if buf.remaining() < len {
216 Err(Error::parse(io::Error::new(
217 io::ErrorKind::UnexpectedEof,
218 "unexpected EOF",
219 )))
220 } else {
221 Ok(())
222 }
223}
224
/// A row of data decoded from a binary `COPY` stream.
pub struct BinaryCopyOutRow {
    // The raw chunk the row was decoded from; the entries of `ranges` index into it.
    buf: Bytes,
    // Byte range of each field's payload within `buf`, or `None` when the field's
    // on-wire length was -1 (decoded via `from_sql_null`).
    ranges: Vec<Option<Range<usize>>>,
    // Column types, shared with the stream that produced this row.
    types: Arc<Vec<Type>>,
}
231
232impl BinaryCopyOutRow {
233 pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
235 where
236 T: FromSql<'a>,
237 {
238 let type_ = match self.types.get(idx) {
239 Some(type_) => type_,
240 None => return Err(Error::column(idx.to_string())),
241 };
242
243 if !T::accepts(type_) {
244 return Err(Error::from_sql(
245 Box::new(WrongType::new::<T>(type_.clone())),
246 idx,
247 ));
248 }
249
250 let r = match &self.ranges[idx] {
251 Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
252 None => T::from_sql_null(type_),
253 };
254
255 r.map_err(|e| Error::from_sql(e, idx))
256 }
257
258 pub fn get<'a, T>(&'a self, idx: usize) -> T
264 where
265 T: FromSql<'a>,
266 {
267 match self.try_get(idx) {
268 Ok(value) => value,
269 Err(e) => panic!("error retrieving column {}: {}", idx, e),
270 }
271 }
272}