xitca_postgres/driver/
codec.rs1pub(crate) mod encode;
2pub(crate) mod response;
3
4use core::{
5 future::{poll_fn, Future},
6 task::{ready, Context, Poll},
7};
8
9use postgres_protocol::message::backend;
10use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
11use xitca_io::bytes::BytesMut;
12
13use crate::{
14 error::{Completed, DriverDownReceiving, Error},
15 types::BorrowToSql,
16};
17
18use super::DriverTx;
19
20pub(super) fn request_pair(msg_count: usize) -> (ResponseSender, Response) {
21 let (tx, rx) = unbounded_channel();
22 (
23 ResponseSender { tx, msg_count },
24 Response {
25 rx,
26 buf: BytesMut::new(),
27 complete: false,
28 },
29 )
30}
31
/// Receiving half of the pair created by `request_pair`.
///
/// Chunks of raw backend messages arrive over `rx`. Each chunk is buffered and then
/// parsed one message at a time.
#[derive(Debug)]
pub struct Response {
    /// Channel fed by the paired `ResponseSender`.
    rx: ResponseReceiver,
    /// Bytes of the most recently received chunk that have not been parsed yet.
    buf: BytesMut,
    /// `complete` flag of the most recently received chunk. When the channel closes,
    /// this distinguishes a finished response (`Completed`) from a lost driver
    /// (`DriverDownReceiving`).
    complete: bool,
}
38
39impl Response {
40 pub(crate) fn blocking_recv(&mut self) -> Result<backend::Message, Error> {
41 if self.buf.is_empty() {
42 let res = self.rx.blocking_recv();
43 self.on_recv(res)?;
44 }
45 self.parse_message()
46 }
47
48 pub(crate) fn recv(&mut self) -> impl Future<Output = Result<backend::Message, Error>> + Send + '_ {
49 poll_fn(|cx| self.poll_recv(cx))
50 }
51
52 pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<backend::Message, Error>> {
53 if self.buf.is_empty() {
54 let res = ready!(self.rx.poll_recv(cx));
55 self.on_recv(res)?;
56 }
57 Poll::Ready(self.parse_message())
58 }
59
60 pub(crate) fn try_into_row_affected(mut self) -> impl Future<Output = Result<u64, Error>> + Send {
61 let mut rows = 0;
62 poll_fn(move |cx| {
63 ready!(self.poll_try_into_ready(&mut rows, cx))?;
64 Poll::Ready(Ok(rows))
65 })
66 }
67
68 pub(crate) fn try_into_row_affected_blocking(mut self) -> Result<u64, Error> {
69 let mut rows = 0;
70 loop {
71 match self.blocking_recv()? {
72 backend::Message::BindComplete
73 | backend::Message::NoData
74 | backend::Message::ParseComplete
75 | backend::Message::ParameterDescription(_)
76 | backend::Message::RowDescription(_)
77 | backend::Message::DataRow(_)
78 | backend::Message::EmptyQueryResponse
79 | backend::Message::PortalSuspended => {}
80 backend::Message::CommandComplete(body) => {
81 rows = body_to_affected_rows(&body)?;
82 }
83 backend::Message::ReadyForQuery(_) => return Ok(rows),
84 _ => return Err(Error::unexpected()),
85 }
86 }
87 }
88
89 pub(crate) fn poll_try_into_ready(&mut self, rows: &mut u64, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
90 loop {
91 match ready!(self.poll_recv(cx))? {
92 backend::Message::BindComplete
93 | backend::Message::NoData
94 | backend::Message::ParseComplete
95 | backend::Message::ParameterDescription(_)
96 | backend::Message::RowDescription(_)
97 | backend::Message::DataRow(_)
98 | backend::Message::EmptyQueryResponse
99 | backend::Message::PortalSuspended => {}
100 backend::Message::CommandComplete(body) => {
101 *rows = body_to_affected_rows(&body)?;
102 }
103 backend::Message::ReadyForQuery(_) => return Poll::Ready(Ok(())),
104 _ => return Poll::Ready(Err(Error::unexpected())),
105 }
106 }
107 }
108
109 fn on_recv(&mut self, res: Option<BytesMessage>) -> Result<(), Error> {
110 match res {
111 Some(msg) => {
112 self.complete = msg.complete;
113 self.buf = msg.buf;
114 Ok(())
115 }
116 None => {
117 return if self.complete {
118 Err(Completed.into())
119 } else {
120 Err(DriverDownReceiving.into())
121 }
122 }
123 }
124 }
125
126 fn parse_message(&mut self) -> Result<backend::Message, Error> {
127 match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.") {
128 backend::Message::ErrorResponse(body) => Err(Error::db(body.fields())),
129 msg => Ok(msg),
130 }
131 }
132}
133
134pub(crate) fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result<u64, Error> {
136 body.tag()
137 .map_err(|_| Error::todo())
138 .map(|r| r.rsplit(' ').next().unwrap().parse().unwrap_or(0))
139}
140
/// Sending half of the pair created by `request_pair`.
#[derive(Debug)]
pub(crate) struct ResponseSender {
    /// Channel feeding the paired `Response`.
    tx: UnboundedSender<BytesMessage>,
    /// Number of chunks flagged `complete` still expected before this sender reports
    /// `SenderState::Finish`.
    msg_count: usize,
}
146
/// Outcome of `ResponseSender::send`.
pub(super) enum SenderState {
    /// More chunks flagged `complete` are still expected.
    Continue,
    /// The expected number of chunks flagged `complete` has been delivered.
    Finish,
}
151
152impl ResponseSender {
153 pub(super) fn send(&mut self, msg: BytesMessage) -> SenderState {
154 debug_assert!(self.msg_count > 0);
155
156 if msg.complete {
157 self.msg_count -= 1;
158 }
159
160 let _ = self.tx.send(msg);
161
162 if self.msg_count == 0 {
163 SenderState::Finish
164 } else {
165 SenderState::Continue
166 }
167 }
168}
169
/// Receiving end of the channel that carries `BytesMessage` chunks to a `Response`.
// NOTE(review): this alias is used by `Response::rx`, so the reason for
// `allow(dead_code)` is not visible here. Confirm whether it is still needed.
#[allow(dead_code)]
pub(super) type ResponseReceiver = UnboundedReceiver<BytesMessage>;
173
/// A chunk of raw backend messages split off the connection's read buffer
/// (as produced by `ResponseMessage::try_from_buf`).
pub(super) struct BytesMessage {
    /// One or more whole backend messages.
    buf: BytesMut,
    /// `true` when the chunk ends with a `ReadyForQuery` message.
    complete: bool,
}
178
179impl BytesMessage {
180 #[cold]
181 #[inline(never)]
182 pub(super) fn parse_error(&mut self) -> Error {
183 match backend::Message::parse(&mut self.buf) {
184 Err(e) => Error::from(e),
185 Ok(Some(backend::Message::ErrorResponse(body))) => Error::db(body.fields()),
186 _ => Error::unexpected(),
187 }
188 }
189}
190
/// One unit decoded from the connection's read buffer by `ResponseMessage::try_from_buf`.
pub(super) enum ResponseMessage {
    /// A run of whole non-asynchronous messages. Its `complete` flag is set when the
    /// run ends with `ReadyForQuery`.
    Normal(BytesMessage),
    /// A single `NoticeResponse`, `NotificationResponse` or `ParameterStatus` message,
    /// already parsed.
    Async(backend::Message),
}
195
196impl ResponseMessage {
197 pub(crate) fn try_from_buf(buf: &mut BytesMut) -> Result<Option<Self>, Error> {
198 let mut tail = 0;
199 let mut complete = false;
200
201 loop {
202 let slice = &buf[tail..];
203 let Some(header) = backend::Header::parse(slice)? else {
204 break;
205 };
206 let len = header.len() as usize + 1;
207
208 if slice.len() < len {
209 break;
210 }
211
212 match header.tag() {
213 backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG | backend::PARAMETER_STATUS_TAG => {
214 if tail > 0 {
215 break;
216 }
217 let message = backend::Message::parse(buf)?
218 .expect("buffer contains at least one Message. parser must produce Some");
219 return Ok(Some(ResponseMessage::Async(message)));
220 }
221 tag => {
222 tail += len;
223 if matches!(tag, backend::READY_FOR_QUERY_TAG) {
224 complete = true;
225 break;
226 }
227 }
228 }
229 }
230
231 if tail == 0 {
232 Ok(None)
233 } else {
234 Ok(Some(ResponseMessage::Normal(BytesMessage {
235 buf: buf.split_to(tail),
236 complete,
237 })))
238 }
239 }
240}
241
/// Bound for parameter collections: anything iterable with an exact length whose
/// items implement `BorrowToSql`.
///
/// The bounds are written as associated-type bounds on the supertrait, so they are
/// implied wherever `AsParams` is used.
// NOTE(review): presumably the parameter list passed to statement execution; confirm
// against callers.
pub trait AsParams: IntoIterator<IntoIter: ExactSizeIterator<Item: BorrowToSql>> {}

/// Blanket impl: every type meeting the bounds is `AsParams`, so no manual impls are
/// needed.
impl<I> AsParams for I where I: IntoIterator<IntoIter: ExactSizeIterator<Item: BorrowToSql>> {}
247
/// Private module hosting a marker trait. Because the module is not `pub`, code
/// outside the enclosing module and its submodules cannot name `Sealed`.
// NOTE(review): `Sealed` is not referenced in the visible part of this file.
// Presumably it is a supertrait in a child module, used to seal a public trait.
// Confirm before removing.
mod sealed {
    pub trait Sealed {}
}