xitca_postgres/
copy.rs

1use core::future::Future;
2
3use postgres_protocol::message::{backend, frontend};
4use xitca_io::bytes::{Buf, Bytes, BytesMut};
5
6use super::{
7    client::{Client, ClientBorrowMut},
8    driver::codec::Response,
9    error::Error,
10    iter::AsyncLendingIterator,
11    query::Query,
12    statement::Statement,
13};
14
15pub trait r#Copy: Query + ClientBorrowMut {
16    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
17    where
18        F: FnOnce(&mut BytesMut) -> Result<(), Error>;
19}
20
21impl r#Copy for Client {
22    #[inline]
23    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
24    where
25        F: FnOnce(&mut BytesMut) -> Result<(), Error>,
26    {
27        self.tx.send_one_way(func)
28    }
29}
30
31pub struct CopyIn<'a, C>
32where
33    C: r#Copy + Send,
34{
35    client: &'a mut C,
36    res: Option<Response>,
37}
38
39impl<C> Drop for CopyIn<'_, C>
40where
41    C: r#Copy + Send,
42{
43    fn drop(&mut self) {
44        // when response is not taken on drop it means the progress is aborted before finish.
45        // cancel the copy in this case
46        if self.res.is_some() {
47            self.do_cancel();
48        }
49    }
50}
51
52impl<'a, C> CopyIn<'a, C>
53where
54    C: r#Copy + Send,
55{
56    pub fn new(client: &'a mut C, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
57        // marker check to ensure exclusive borrowing Client. see ClientBorrowMut for detail
58        let _cli = client._borrow_mut();
59
60        let res = client._send_encode_query(stmt.bind_none()).map(|(_, res)| res);
61
62        async {
63            let mut res = res?;
64            match res.recv().await? {
65                backend::Message::BindComplete => {}
66                _ => return Err(Error::unexpected()),
67            }
68
69            match res.recv().await? {
70                backend::Message::CopyInResponse(_) => {}
71                _ => return Err(Error::unexpected()),
72            }
73
74            Ok(CopyIn { client, res: Some(res) })
75        }
76    }
77
78    /// copy given buffer into [`Driver`] and send it to database in non blocking manner
79    ///
80    /// *. calling this api in rapid succession and/or supply huge buffer may result in high memory consumption.
81    /// consider rate limiting the progress with small chunk of buffer and/or using smart pointers for throughput
82    /// counting
83    ///
84    /// [`Driver`]: crate::driver::Driver
85    pub fn copy(&mut self, item: impl Buf) -> Result<(), Error> {
86        let data = frontend::CopyData::new(item)?;
87        self.client.send_one_way(|buf| {
88            data.write(buf);
89            Ok(())
90        })
91    }
92
93    /// finish copy in and return how many rows are affected
94    pub async fn finish(mut self) -> Result<u64, Error> {
95        self.client.send_one_way(|buf| {
96            frontend::copy_done(buf);
97            frontend::sync(buf);
98            Ok(())
99        })?;
100        self.res.take().unwrap().try_into_row_affected().await
101    }
102
103    fn do_cancel(&mut self) {
104        let _ = self.client.send_one_way(|buf| {
105            frontend::copy_fail("", buf)?;
106            frontend::sync(buf);
107            Ok(())
108        });
109    }
110}
111
112pub struct CopyOut {
113    res: Response,
114}
115
116impl CopyOut {
117    pub fn new(cli: &impl Query, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
118        let res = cli._send_encode_query(stmt.bind_none()).map(|(_, res)| res);
119
120        async {
121            let mut res = res?;
122
123            match res.recv().await? {
124                backend::Message::BindComplete => {}
125                _ => return Err(Error::unexpected()),
126            }
127
128            match res.recv().await? {
129                backend::Message::CopyOutResponse(_) => {}
130                _ => return Err(Error::unexpected()),
131            }
132
133            Ok(CopyOut { res })
134        }
135    }
136}
137
138impl AsyncLendingIterator for CopyOut {
139    type Ok<'i>
140        = Bytes
141    where
142        Self: 'i;
143    type Err = Error;
144
145    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
146        match self.res.recv().await? {
147            backend::Message::CopyData(body) => Ok(Some(body.into_bytes())),
148            backend::Message::CopyDone => Ok(None),
149            _ => Err(Error::unexpected()),
150        }
151    }
152}