tokio_postgres/
copy_in.rs1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::query::extract_row_affected;
5use crate::{query, slice_iter, Error, Statement};
6use bytes::{Buf, BufMut, BytesMut};
7use futures_channel::mpsc;
8use futures_util::{Sink, SinkExt, Stream, StreamExt};
9use log::debug;
10use pin_project_lite::pin_project;
11use postgres_protocol::message::backend::Message;
12use postgres_protocol::message::frontend;
13use postgres_protocol::message::frontend::CopyData;
14use std::future;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{ready, Context, Poll};
18
19enum CopyInMessage {
20 Message(FrontendMessage),
21 Done,
22}
23
24pub struct CopyInReceiver {
25 receiver: mpsc::Receiver<CopyInMessage>,
26 done: bool,
27}
28
29impl CopyInReceiver {
30 fn new(receiver: mpsc::Receiver<CopyInMessage>) -> CopyInReceiver {
31 CopyInReceiver {
32 receiver,
33 done: false,
34 }
35 }
36}
37
38impl Stream for CopyInReceiver {
39 type Item = FrontendMessage;
40
41 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
42 if self.done {
43 return Poll::Ready(None);
44 }
45
46 match ready!(self.receiver.poll_next_unpin(cx)) {
47 Some(CopyInMessage::Message(message)) => Poll::Ready(Some(message)),
48 Some(CopyInMessage::Done) => {
49 self.done = true;
50 let mut buf = BytesMut::new();
51 frontend::copy_done(&mut buf);
52 frontend::sync(&mut buf);
53 Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
54 }
55 None => {
56 self.done = true;
57 let mut buf = BytesMut::new();
58 frontend::copy_fail("", &mut buf).unwrap();
59 frontend::sync(&mut buf);
60 Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
61 }
62 }
63 }
64}
65
66enum SinkState {
67 Active,
68 Closing,
69 Reading,
70}
71
72pin_project! {
73 #[project(!Unpin)]
78 pub struct CopyInSink<T> {
79 #[pin]
80 sender: mpsc::Sender<CopyInMessage>,
81 responses: Responses,
82 buf: BytesMut,
83 state: SinkState,
84 _p2: PhantomData<T>,
85 }
86}
87
88impl<T> CopyInSink<T>
89where
90 T: Buf + 'static + Send,
91{
92 pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
94 loop {
95 match self.state {
96 SinkState::Active => {
97 ready!(self.as_mut().poll_flush(cx))?;
98 let mut this = self.as_mut().project();
99 ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
100 this.sender
101 .start_send(CopyInMessage::Done)
102 .map_err(|_| Error::closed())?;
103 *this.state = SinkState::Closing;
104 }
105 SinkState::Closing => {
106 let this = self.as_mut().project();
107 ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
108 *this.state = SinkState::Reading;
109 }
110 SinkState::Reading => {
111 let this = self.as_mut().project();
112 match ready!(this.responses.poll_next(cx))? {
113 Message::CommandComplete(body) => {
114 let rows = extract_row_affected(&body)?;
115 return Poll::Ready(Ok(rows));
116 }
117 _ => return Poll::Ready(Err(Error::unexpected_message())),
118 }
119 }
120 }
121 }
122 }
123
124 pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
129 future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
130 }
131}
132
133impl<T> Sink<T> for CopyInSink<T>
134where
135 T: Buf + 'static + Send,
136{
137 type Error = Error;
138
139 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
140 self.project()
141 .sender
142 .poll_ready(cx)
143 .map_err(|_| Error::closed())
144 }
145
146 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
147 let this = self.project();
148
149 let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
150 if this.buf.is_empty() {
151 Box::new(item)
152 } else {
153 Box::new(this.buf.split().freeze().chain(item))
154 }
155 } else {
156 this.buf.put(item);
157 if this.buf.len() > 4096 {
158 Box::new(this.buf.split().freeze())
159 } else {
160 return Ok(());
161 }
162 };
163
164 let data = CopyData::new(data).map_err(Error::encode)?;
165 this.sender
166 .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
167 .map_err(|_| Error::closed())
168 }
169
170 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
171 let mut this = self.project();
172
173 if !this.buf.is_empty() {
174 ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
175 let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
176 let data = CopyData::new(data).map_err(Error::encode)?;
177 this.sender
178 .as_mut()
179 .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
180 .map_err(|_| Error::closed())?;
181 }
182
183 this.sender.poll_flush(cx).map_err(|_| Error::closed())
184 }
185
186 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
187 self.poll_finish(cx).map_ok(|_| ())
188 }
189}
190
191pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error>
192where
193 T: Buf + 'static + Send,
194{
195 debug!("executing copy in statement {}", statement.name());
196
197 let buf = query::encode(client, &statement, slice_iter(&[]))?;
198
199 let (mut sender, receiver) = mpsc::channel(1);
200 let receiver = CopyInReceiver::new(receiver);
201 let mut responses = client.send(RequestMessages::CopyIn(receiver))?;
202
203 sender
204 .send(CopyInMessage::Message(FrontendMessage::Raw(buf)))
205 .await
206 .map_err(|_| Error::closed())?;
207
208 match responses.next().await? {
209 Message::BindComplete => {}
210 _ => return Err(Error::unexpected_message()),
211 }
212
213 match responses.next().await? {
214 Message::CopyInResponse(_) => {}
215 _ => return Err(Error::unexpected_message()),
216 }
217
218 Ok(CopyInSink {
219 sender,
220 responses,
221 buf: BytesMut::new(),
222 state: SinkState::Active,
223 _p2: PhantomData,
224 })
225}