tokio_postgres/copy_in.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::query::extract_row_affected;
5use crate::{query, slice_iter, Error, Statement};
6use bytes::{Buf, BufMut, BytesMut};
7use futures_channel::mpsc;
8use futures_util::{Sink, SinkExt, Stream, StreamExt};
9use log::debug;
10use pin_project_lite::pin_project;
11use postgres_protocol::message::backend::Message;
12use postgres_protocol::message::frontend;
13use postgres_protocol::message::frontend::CopyData;
14use std::future;
15use std::marker::PhantomData;
16use std::pin::Pin;
17use std::task::{ready, Context, Poll};
18
/// Items carried over the channel from a [`CopyInSink`] to its [`CopyInReceiver`].
enum CopyInMessage {
    /// A frontend message to forward unchanged (e.g. the encoded statement sent by `copy_in`,
    /// or a `CopyData` chunk produced by the sink).
    Message(FrontendMessage),
    /// The copy finished normally; the receiver answers by emitting `CopyDone` followed by
    /// `Sync` and then ends its stream.
    Done,
}
23
24pub struct CopyInReceiver {
25    receiver: mpsc::Receiver<CopyInMessage>,
26    done: bool,
27}
28
29impl CopyInReceiver {
30    fn new(receiver: mpsc::Receiver<CopyInMessage>) -> CopyInReceiver {
31        CopyInReceiver {
32            receiver,
33            done: false,
34        }
35    }
36}
37
38impl Stream for CopyInReceiver {
39    type Item = FrontendMessage;
40
41    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
42        if self.done {
43            return Poll::Ready(None);
44        }
45
46        match ready!(self.receiver.poll_next_unpin(cx)) {
47            Some(CopyInMessage::Message(message)) => Poll::Ready(Some(message)),
48            Some(CopyInMessage::Done) => {
49                self.done = true;
50                let mut buf = BytesMut::new();
51                frontend::copy_done(&mut buf);
52                frontend::sync(&mut buf);
53                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
54            }
55            None => {
56                self.done = true;
57                let mut buf = BytesMut::new();
58                frontend::copy_fail("", &mut buf).unwrap();
59                frontend::sync(&mut buf);
60                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
61            }
62        }
63    }
64}
65
/// Progress of a [`CopyInSink`] through the close sequence driven by `poll_finish`.
enum SinkState {
    /// Initial state: `CopyInMessage::Done` has not been queued yet.
    Active,
    /// `Done` has been queued; the channel is being closed.
    Closing,
    /// The channel is closed; waiting for the server's `CommandComplete` response.
    Reading,
}
71
pin_project! {
    /// A sink for `COPY ... FROM STDIN` query data.
    ///
    /// Items written to the sink are coalesced in an internal buffer and sent as `CopyData`
    /// messages.
    ///
    /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
    /// not, the copy will be aborted.
    #[project(!Unpin)]
    pub struct CopyInSink<T> {
        // Sending half of the channel drained by the matching `CopyInReceiver`.
        #[pin]
        sender: mpsc::Sender<CopyInMessage>,
        // Responses to the copy request; `poll_finish` reads the final `CommandComplete` here.
        responses: Responses,
        // Small writes accumulate here until more than 4096 bytes are pending.
        buf: BytesMut,
        // Where `poll_finish` currently is in its close sequence.
        state: SinkState,
        // Ties the sink to the item type `T` accepted by its `Sink` impl without storing one.
        _p2: PhantomData<T>,
    }
}
87
impl<T> CopyInSink<T>
where
    T: Buf + 'static + Send,
{
    /// A poll-based version of `finish`.
    ///
    /// Drives the sink through its close sequence, resuming from `self.state` on each call:
    ///
    /// 1. `Active`: flushes any buffered data, then queues `CopyInMessage::Done`. The receiver
    ///    turns that into `CopyDone` + `Sync`.
    /// 2. `Closing`: closes the channel.
    /// 3. `Reading`: waits for the next response and returns the row count parsed from
    ///    `CommandComplete`.
    ///
    /// # Errors
    ///
    /// - `Error::closed()` if the channel rejects a message or fails to close.
    /// - Any error from flushing, from the response stream, or from parsing the row count.
    /// - `Error::unexpected_message()` if the response is not `CommandComplete`.
    pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
        loop {
            match self.state {
                SinkState::Active => {
                    // `Sink::poll_flush`: buffered bytes must be queued ahead of `Done`.
                    ready!(self.as_mut().poll_flush(cx))?;
                    let mut this = self.as_mut().project();
                    // Reserve channel capacity before sending, as the `Sink` contract requires.
                    ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
                    this.sender
                        .start_send(CopyInMessage::Done)
                        .map_err(|_| Error::closed())?;
                    *this.state = SinkState::Closing;
                }
                SinkState::Closing => {
                    let this = self.as_mut().project();
                    ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
                    *this.state = SinkState::Reading;
                }
                SinkState::Reading => {
                    let this = self.as_mut().project();
                    match ready!(this.responses.poll_next(cx))? {
                        Message::CommandComplete(body) => {
                            let rows = extract_row_affected(&body)?;
                            return Poll::Ready(Ok(rows));
                        }
                        _ => return Poll::Ready(Err(Error::unexpected_message())),
                    }
                }
            }
        }
    }

    /// Completes the copy, returning the number of rows inserted.
    ///
    /// The `Sink::close` method is equivalent to `finish`, except that it does not return the
    /// number of rows.
    ///
    /// # Errors
    ///
    /// Same as `poll_finish`, which this method drives to completion.
    pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
        future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
    }
}
132
133impl<T> Sink<T> for CopyInSink<T>
134where
135    T: Buf + 'static + Send,
136{
137    type Error = Error;
138
139    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
140        self.project()
141            .sender
142            .poll_ready(cx)
143            .map_err(|_| Error::closed())
144    }
145
146    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
147        let this = self.project();
148
149        let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
150            if this.buf.is_empty() {
151                Box::new(item)
152            } else {
153                Box::new(this.buf.split().freeze().chain(item))
154            }
155        } else {
156            this.buf.put(item);
157            if this.buf.len() > 4096 {
158                Box::new(this.buf.split().freeze())
159            } else {
160                return Ok(());
161            }
162        };
163
164        let data = CopyData::new(data).map_err(Error::encode)?;
165        this.sender
166            .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
167            .map_err(|_| Error::closed())
168    }
169
170    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
171        let mut this = self.project();
172
173        if !this.buf.is_empty() {
174            ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
175            let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
176            let data = CopyData::new(data).map_err(Error::encode)?;
177            this.sender
178                .as_mut()
179                .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
180                .map_err(|_| Error::closed())?;
181        }
182
183        this.sender.poll_flush(cx).map_err(|_| Error::closed())
184    }
185
186    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
187        self.poll_finish(cx).map_ok(|_| ())
188    }
189}
190
191pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error>
192where
193    T: Buf + 'static + Send,
194{
195    debug!("executing copy in statement {}", statement.name());
196
197    let buf = query::encode(client, &statement, slice_iter(&[]))?;
198
199    let (mut sender, receiver) = mpsc::channel(1);
200    let receiver = CopyInReceiver::new(receiver);
201    let mut responses = client.send(RequestMessages::CopyIn(receiver))?;
202
203    sender
204        .send(CopyInMessage::Message(FrontendMessage::Raw(buf)))
205        .await
206        .map_err(|_| Error::closed())?;
207
208    match responses.next().await? {
209        Message::BindComplete => {}
210        _ => return Err(Error::unexpected_message()),
211    }
212
213    match responses.next().await? {
214        Message::CopyInResponse(_) => {}
215        _ => return Err(Error::unexpected_message()),
216    }
217
218    Ok(CopyInSink {
219        sender,
220        responses,
221        buf: BytesMut::new(),
222        state: SinkState::Active,
223        _p2: PhantomData,
224    })
225}