xitca_postgres/
copy.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
use core::future::Future;

use postgres_protocol::message::{backend, frontend};
use xitca_io::bytes::{Buf, Bytes, BytesMut};

use super::{
    client::ClientBorrowMut, driver::codec::Response, error::Error, iter::AsyncLendingIterator, query::Query,
    statement::Statement,
};

pub trait r#Copy: Query + ClientBorrowMut {
    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
    where
        F: FnOnce(&mut BytesMut) -> Result<(), Error>;
}

/// Handle for an in-progress copy-in (`COPY ... FROM STDIN`) operation.
///
/// Obtained from [`CopyIn::new`]. Data is streamed with [`CopyIn::copy`] and the operation is
/// completed with [`CopyIn::finish`]. Dropping the handle before finishing cancels the copy.
pub struct CopyIn<'a, C: r#Copy + Send> {
    /// Exclusively borrowed client that the copy messages are written through.
    client: &'a mut C,
    /// Pending response of the copy statement; `None` once [`CopyIn::finish`] has taken it.
    res: Option<Response>,
}

impl<C: r#Copy + Send> Drop for CopyIn<'_, C> {
    /// Cancels the copy if the handle is dropped without being finished.
    fn drop(&mut self) {
        // `finish` takes the response out, so an absent response means the copy completed
        // normally and there is nothing left to abort.
        if self.res.is_none() {
            return;
        }
        // The response is left in place (not taken) so it is dropped only after the
        // cancellation messages have been written.
        self.do_cancel();
    }
}

impl<'a, C> CopyIn<'a, C>
where
    C: r#Copy + Send,
{
    /// Start a copy-in operation for `stmt`.
    ///
    /// The statement is encoded and submitted through `_send_encode_query` eagerly, before the
    /// returned future is polled, so the future does not capture `stmt`. The future then expects
    /// `BindComplete` followed by `CopyInResponse`. It resolves to a [`CopyIn`] handle that keeps
    /// the exclusive borrow of `client` until the handle is finished or dropped.
    ///
    /// # Errors
    /// Returns the error from submitting the query or from receiving a message. Returns
    /// `Error::unexpected()` when the server answers with any other message.
    pub fn new(client: &'a mut C, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
        // marker check to ensure exclusive borrowing Client. see ClientBorrowMut for detail
        let _cli = client._borrow_mut();

        let res = client._send_encode_query(stmt).map(|(_, res)| res);

        async {
            let mut res = res?;
            match res.recv().await? {
                backend::Message::BindComplete => {}
                _ => return Err(Error::unexpected()),
            }

            match res.recv().await? {
                backend::Message::CopyInResponse(_) => {}
                _ => return Err(Error::unexpected()),
            }

            Ok(CopyIn { client, res: Some(res) })
        }
    }

    /// Copy the given buffer into [`Driver`] and send it to the database without blocking.
    ///
    /// Each call wraps `item` in a single `CopyData` message.
    ///
    /// Note: calling this API in rapid succession and/or supplying huge buffers may result in
    /// high memory consumption. Consider rate limiting the progress with small chunks of buffer
    /// and/or using smart pointers for throughput counting.
    ///
    /// # Errors
    /// Returns an error when `frontend::CopyData::new` rejects `item`, or when the message
    /// cannot be handed to the client.
    ///
    /// [`Driver`]: crate::driver::Driver
    pub fn copy(&mut self, item: impl Buf) -> Result<(), Error> {
        let data = frontend::CopyData::new(item)?;
        self.client.send_one_way(|buf| {
            data.write(buf);
            Ok(())
        })
    }

    /// Finish the copy-in and return how many rows are affected.
    ///
    /// Sends `CopyDone` followed by `Sync`, then awaits the statement's completion.
    ///
    /// # Errors
    /// If the messages cannot be sent, the error is returned. The handle is then dropped with
    /// its response still pending, so `Drop` attempts to cancel the copy. Errors reported while
    /// awaiting completion are returned as-is.
    pub async fn finish(mut self) -> Result<u64, Error> {
        self.client.send_one_way(|buf| {
            frontend::copy_done(buf);
            frontend::sync(buf);
            Ok(())
        })?;
        // Taking the response marks the copy as finished, so `Drop` will not cancel it.
        let res = self
            .res
            .take()
            .expect("CopyIn response is present until finish or drop consumes it");
        res.try_into_row_affected().await
    }

    /// Abort the copy-in by sending `CopyFail` (with an empty message) followed by `Sync`.
    ///
    /// Best effort: errors from encoding or sending are deliberately ignored because this runs
    /// from `Drop`, where there is no caller to report them to. `Sync` is sent because, per the
    /// PostgreSQL COPY protocol, the server discards frontend messages after a failed
    /// extended-query COPY until it receives `Sync`.
    fn do_cancel(&mut self) {
        let _ = self.client.send_one_way(|buf| {
            frontend::copy_fail("", buf)?;
            frontend::sync(buf);
            Ok(())
        });
    }
}

/// Stream of raw copy-out (`COPY ... TO STDOUT`) data.
///
/// Obtained from [`CopyOut::new`] and consumed through its [`AsyncLendingIterator`]
/// implementation, which yields each `CopyData` payload as [`Bytes`].
pub struct CopyOut {
    /// Response stream that the copy-out messages are read from.
    res: Response,
}

impl CopyOut {
    pub fn new(cli: &impl Query, stmt: &Statement) -> impl Future<Output = Result<Self, Error>> + Send {
        let res = cli._send_encode_query(stmt).map(|(_, res)| res);

        async {
            let mut res = res?;

            match res.recv().await? {
                backend::Message::BindComplete => {}
                _ => return Err(Error::unexpected()),
            }

            match res.recv().await? {
                backend::Message::CopyOutResponse(_) => {}
                _ => return Err(Error::unexpected()),
            }

            Ok(CopyOut { res })
        }
    }
}

impl AsyncLendingIterator for CopyOut {
    type Ok<'i>
        = Bytes
    where
        Self: 'i;
    type Err = Error;

    async fn try_next(&mut self) -> Result<Option<Self::Ok<'_>>, Self::Err> {
        match self.res.recv().await? {
            backend::Message::CopyData(body) => Ok(Some(body.into_bytes())),
            backend::Message::CopyDone => Ok(None),
            _ => Err(Error::unexpected()),
        }
    }
}