xitca_postgres/
copy.rs

1use core::future::Future;
2
3use postgres_protocol::message::{backend, frontend};
4use xitca_io::bytes::{Buf, Bytes, BytesMut};
5
6use super::{
7    client::ClientBorrowMut, driver::codec::Response, error::Error, iter::AsyncLendingIterator, query::Query,
8    statement::Statement,
9};
10
11pub trait r#Copy: Query + ClientBorrowMut {
12    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
13    where
14        F: FnOnce(&mut BytesMut) -> Result<(), Error>;
15}
16
17pub struct CopyIn<'a, C>
18where
19    C: r#Copy + Send,
20{
21    client: &'a mut C,
22    res: Option<Response>,
23}
24
25impl<C> Drop for CopyIn<'_, C>
26where
27    C: r#Copy + Send,
28{
29    fn drop(&mut self) {
30        // when response is not taken on drop it means the progress is aborted before finish.
31        // cancel the copy in this case
32        if self.res.is_some() {
33            self.do_cancel();
34        }
35    }
36}
37
38impl<'a, C> CopyIn<'a, C>
39where
40    C: r#Copy + Send,
41{
42    pub fn new(client: &'a mut C, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
43        // marker check to ensure exclusive borrowing Client. see ClientBorrowMut for detail
44        let _cli = client._borrow_mut();
45
46        let res = client._send_encode_query(stmt).map(|(_, res)| res);
47
48        async {
49            let mut res = res?;
50            match res.recv().await? {
51                backend::Message::BindComplete => {}
52                _ => return Err(Error::unexpected()),
53            }
54
55            match res.recv().await? {
56                backend::Message::CopyInResponse(_) => {}
57                _ => return Err(Error::unexpected()),
58            }
59
60            Ok(CopyIn { client, res: Some(res) })
61        }
62    }
63
64    /// copy given buffer into [`Driver`] and send it to database in non blocking manner
65    ///
66    /// *. calling this api in rapid succession and/or supply huge buffer may result in high memory consumption.
67    /// consider rate limiting the progress with small chunk of buffer and/or using smart pointers for throughput
68    /// counting
69    ///
70    /// [`Driver`]: crate::driver::Driver
71    pub fn copy(&mut self, item: impl Buf) -> Result<(), Error> {
72        let data = frontend::CopyData::new(item)?;
73        self.client.send_one_way(|buf| {
74            data.write(buf);
75            Ok(())
76        })
77    }
78
79    /// finish copy in and return how many rows are affected
80    pub async fn finish(mut self) -> Result<u64, Error> {
81        self.client.send_one_way(|buf| {
82            frontend::copy_done(buf);
83            frontend::sync(buf);
84            Ok(())
85        })?;
86        self.res.take().unwrap().try_into_row_affected().await
87    }
88
89    fn do_cancel(&mut self) {
90        let _ = self.client.send_one_way(|buf| {
91            frontend::copy_fail("", buf)?;
92            frontend::sync(buf);
93            Ok(())
94        });
95    }
96}
97
/// Handle to an in-progress copy-out operation.
///
/// Copied data is pulled from it through its [`AsyncLendingIterator`] implementation.
pub struct CopyOut {
    // response of the copy statement that copy data messages are received from.
    res: Response,
}
101
102impl CopyOut {
103    pub fn new(cli: &impl Query, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
104        let res = cli._send_encode_query(stmt).map(|(_, res)| res);
105
106        async {
107            let mut res = res?;
108
109            match res.recv().await? {
110                backend::Message::BindComplete => {}
111                _ => return Err(Error::unexpected()),
112            }
113
114            match res.recv().await? {
115                backend::Message::CopyOutResponse(_) => {}
116                _ => return Err(Error::unexpected()),
117            }
118
119            Ok(CopyOut { res })
120        }
121    }
122}
123
124impl AsyncLendingIterator for CopyOut {
125    type Ok<'i>
126        = Bytes
127    where
128        Self: 'i;
129    type Err = Error;
130
131    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
132        match self.res.recv().await? {
133            backend::Message::CopyData(body) => Ok(Some(body.into_bytes())),
134            backend::Message::CopyDone => Ok(None),
135            _ => Err(Error::unexpected()),
136        }
137    }
138}