1use core::future::Future;
2
3use postgres_protocol::message::{backend, frontend};
4use xitca_io::bytes::{Buf, Bytes, BytesMut};
5
6use super::{
7 client::ClientBorrowMut, driver::codec::Response, error::Error, iter::AsyncLendingIterator, query::Query,
8 statement::Statement,
9};
10
11pub trait r#Copy: Query + ClientBorrowMut {
12 fn send_one_way<F>(&self, func: F) -> Result<(), Error>
13 where
14 F: FnOnce(&mut BytesMut) -> Result<(), Error>;
15}
16
17pub struct CopyIn<'a, C>
18where
19 C: r#Copy + Send,
20{
21 client: &'a mut C,
22 res: Option<Response>,
23}
24
25impl<C> Drop for CopyIn<'_, C>
26where
27 C: r#Copy + Send,
28{
29 fn drop(&mut self) {
30 if self.res.is_some() {
33 self.do_cancel();
34 }
35 }
36}
37
38impl<'a, C> CopyIn<'a, C>
39where
40 C: r#Copy + Send,
41{
42 pub fn new(client: &'a mut C, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
43 let _cli = client._borrow_mut();
45
46 let res = client._send_encode_query(stmt).map(|(_, res)| res);
47
48 async {
49 let mut res = res?;
50 match res.recv().await? {
51 backend::Message::BindComplete => {}
52 _ => return Err(Error::unexpected()),
53 }
54
55 match res.recv().await? {
56 backend::Message::CopyInResponse(_) => {}
57 _ => return Err(Error::unexpected()),
58 }
59
60 Ok(CopyIn { client, res: Some(res) })
61 }
62 }
63
64 pub fn copy(&mut self, item: impl Buf) -> Result<(), Error> {
72 let data = frontend::CopyData::new(item)?;
73 self.client.send_one_way(|buf| {
74 data.write(buf);
75 Ok(())
76 })
77 }
78
79 pub async fn finish(mut self) -> Result<u64, Error> {
81 self.client.send_one_way(|buf| {
82 frontend::copy_done(buf);
83 frontend::sync(buf);
84 Ok(())
85 })?;
86 self.res.take().unwrap().try_into_row_affected().await
87 }
88
89 fn do_cancel(&mut self) {
90 let _ = self.client.send_one_way(|buf| {
91 frontend::copy_fail("", buf)?;
92 frontend::sync(buf);
93 Ok(())
94 });
95 }
96}
97
/// Stream of data produced by a `COPY ... TO STDOUT` statement, created by [`CopyOut::new`].
///
/// The payload of each `CopyData` message is yielded through `AsyncLendingIterator::try_next`.
pub struct CopyOut {
    // Response stream the `CopyData` / `CopyDone` messages are received from.
    res: Response,
}
101
102impl CopyOut {
103 pub fn new(cli: &impl Query, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
104 let res = cli._send_encode_query(stmt).map(|(_, res)| res);
105
106 async {
107 let mut res = res?;
108
109 match res.recv().await? {
110 backend::Message::BindComplete => {}
111 _ => return Err(Error::unexpected()),
112 }
113
114 match res.recv().await? {
115 backend::Message::CopyOutResponse(_) => {}
116 _ => return Err(Error::unexpected()),
117 }
118
119 Ok(CopyOut { res })
120 }
121 }
122}
123
124impl AsyncLendingIterator for CopyOut {
125 type Ok<'i>
126 = Bytes
127 where
128 Self: 'i;
129 type Err = Error;
130
131 async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
132 match self.res.recv().await? {
133 backend::Message::CopyData(body) => Ok(Some(body.into_bytes())),
134 backend::Message::CopyDone => Ok(None),
135 _ => Err(Error::unexpected()),
136 }
137 }
138}