Skip to main content

xitca_postgres/
copy.rs

1use core::future::Future;
2
3use xitca_io::bytes::{Buf, Bytes, BytesMut};
4
5use super::{
6    client::{Client, ClientBorrow, ClientBorrowMut},
7    driver::codec::Response,
8    error::Error,
9    iter::AsyncLendingIterator,
10    protocol::message::{backend, frontend},
11    statement::Statement,
12};
13
14pub trait r#Copy: ClientBorrowMut {
15    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
16    where
17        F: FnOnce(&mut BytesMut) -> Result<(), Error>;
18}
19
20impl r#Copy for Client {
21    #[inline]
22    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
23    where
24        F: FnOnce(&mut BytesMut) -> Result<(), Error>,
25    {
26        self.tx.send_one_way(func)
27    }
28}
29
30pub struct CopyIn<'a, C>
31where
32    C: r#Copy + Send,
33{
34    client: &'a mut C,
35    res: Option<Response>,
36}
37
38impl<C> Drop for CopyIn<'_, C>
39where
40    C: r#Copy + Send,
41{
42    fn drop(&mut self) {
43        // when response is not taken on drop it means the progress is aborted before finish.
44        // cancel the copy in this case
45        if self.res.is_some() {
46            self.do_cancel();
47        }
48    }
49}
50
51impl<'a, C> CopyIn<'a, C>
52where
53    C: r#Copy + Send,
54{
55    pub fn new(client: &'a mut C, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
56        // marker check to ensure exclusive borrowing Client. see ClientBorrowMut for detail
57        let res = client.borrow_cli_mut().query_raw(stmt.bind_none()).map(|(_, res)| res);
58
59        async {
60            let mut res = res?;
61            match res.recv().await? {
62                backend::Message::BindComplete => {}
63                _ => return Err(Error::unexpected()),
64            }
65
66            match res.recv().await? {
67                backend::Message::CopyInResponse(_) => {}
68                _ => return Err(Error::unexpected()),
69            }
70
71            Ok(CopyIn { client, res: Some(res) })
72        }
73    }
74
75    /// copy given buffer into [`Driver`] and send it to database in non blocking manner
76    ///
77    /// *. calling this api in rapid succession and/or supply huge buffer may result in high memory consumption.
78    /// consider rate limiting the progress with small chunk of buffer and/or using smart pointers for throughput
79    /// counting
80    ///
81    /// [`Driver`]: crate::driver::Driver
82    pub fn copy(&mut self, item: impl Buf) -> Result<(), Error> {
83        let data = frontend::CopyData::new(item)?;
84        self.client.send_one_way(|buf| {
85            data.write(buf);
86            Ok(())
87        })
88    }
89
90    /// finish copy in and return how many rows are affected
91    pub async fn finish(mut self) -> Result<u64, Error> {
92        self.client.send_one_way(|buf| {
93            frontend::copy_done(buf);
94            frontend::sync(buf);
95            Ok(())
96        })?;
97        self.res.take().unwrap().try_into_row_affected().await
98    }
99
100    fn do_cancel(&mut self) {
101        let _ = self.client.send_one_way(|buf| {
102            frontend::copy_fail("", buf)?;
103            frontend::sync(buf);
104            Ok(())
105        });
106    }
107}
108
/// A copy-out in progress, i.e. the server has answered with `CopyOutResponse`.
///
/// Created by [`CopyOut::new`]. Its [`AsyncLendingIterator`] implementation yields one
/// [`Bytes`] chunk per `CopyData` message until `CopyDone` arrives.
pub struct CopyOut {
    // response of the copy statement; `try_next` reads copy messages from it.
    res: Response,
}
112
113impl CopyOut {
114    pub fn new(cli: &impl ClientBorrow, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
115        let res = cli.borrow_cli_ref().query_raw(stmt.bind_none()).map(|(_, res)| res);
116
117        async {
118            let mut res = res?;
119
120            match res.recv().await? {
121                backend::Message::BindComplete => {}
122                _ => return Err(Error::unexpected()),
123            }
124
125            match res.recv().await? {
126                backend::Message::CopyOutResponse(_) => {}
127                _ => return Err(Error::unexpected()),
128            }
129
130            Ok(CopyOut { res })
131        }
132    }
133}
134
135impl AsyncLendingIterator for CopyOut {
136    type Ok<'i>
137        = Bytes
138    where
139        Self: 'i;
140    type Err = Error;
141
142    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
143        match self.res.recv().await? {
144            backend::Message::CopyData(body) => Ok(Some(body.into_bytes())),
145            backend::Message::CopyDone => Ok(None),
146            _ => Err(Error::unexpected()),
147        }
148    }
149}