// sqlx_core_guts/io/buf_stream.rs

1#![allow(dead_code)]
2
3use std::io;
4use std::ops::{Deref, DerefMut};
5
6use bytes::BytesMut;
7use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWrite};
8
9use crate::error::Error;
10use crate::io::write_and_flush::WriteAndFlush;
11use crate::io::{decode::Decode, encode::Encode};
12use std::io::Cursor;
13
14pub struct BufStream<S>
15where
16    S: AsyncRead + AsyncWrite + Unpin,
17{
18    pub(crate) stream: S,
19
20    // writes with `write` to the underlying stream are buffered
21    // this can be flushed with `flush`
22    pub(crate) wbuf: Vec<u8>,
23
24    // we read into the read buffer using 100% safe code
25    rbuf: BytesMut,
26}
27
28impl<S> BufStream<S>
29where
30    S: AsyncRead + AsyncWrite + Unpin,
31{
32    pub fn new(stream: S) -> Self {
33        Self {
34            stream,
35            wbuf: Vec::with_capacity(512),
36            rbuf: BytesMut::with_capacity(4096),
37        }
38    }
39
40    pub fn write<'en, T>(&mut self, value: T)
41    where
42        T: Encode<'en, ()>,
43    {
44        self.write_with(value, ())
45    }
46
47    pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
48    where
49        T: Encode<'en, C>,
50    {
51        value.encode_with(&mut self.wbuf, context);
52    }
53
54    pub fn flush(&mut self) -> WriteAndFlush<'_, S> {
55        WriteAndFlush {
56            stream: &mut self.stream,
57            buf: Cursor::new(&mut self.wbuf),
58        }
59    }
60
61    pub async fn read<'de, T>(&mut self, cnt: usize) -> Result<T, Error>
62    where
63        T: Decode<'de, ()>,
64    {
65        self.read_with(cnt, ()).await
66    }
67
68    pub async fn read_with<'de, T, C>(&mut self, cnt: usize, context: C) -> Result<T, Error>
69    where
70        T: Decode<'de, C>,
71    {
72        T::decode_with(self.read_raw(cnt).await?.freeze(), context)
73    }
74
75    pub async fn read_raw(&mut self, cnt: usize) -> Result<BytesMut, Error> {
76        read_raw_into(&mut self.stream, &mut self.rbuf, cnt).await?;
77        let buf = self.rbuf.split_to(cnt);
78
79        Ok(buf)
80    }
81
82    pub async fn read_raw_into(&mut self, buf: &mut BytesMut, cnt: usize) -> Result<(), Error> {
83        read_raw_into(&mut self.stream, buf, cnt).await
84    }
85}
86
87impl<S> Deref for BufStream<S>
88where
89    S: AsyncRead + AsyncWrite + Unpin,
90{
91    type Target = S;
92
93    fn deref(&self) -> &Self::Target {
94        &self.stream
95    }
96}
97
98impl<S> DerefMut for BufStream<S>
99where
100    S: AsyncRead + AsyncWrite + Unpin,
101{
102    fn deref_mut(&mut self) -> &mut Self::Target {
103        &mut self.stream
104    }
105}
106
107// Holds a buffer which has been temporarily extended, so that
108// we can read into it. Automatically shrinks the buffer back
109// down if the read is cancelled.
110struct BufTruncator<'a> {
111    buf: &'a mut BytesMut,
112    filled_len: usize,
113}
114
115impl<'a> BufTruncator<'a> {
116    fn new(buf: &'a mut BytesMut) -> Self {
117        let filled_len = buf.len();
118        Self { buf, filled_len }
119    }
120    fn reserve(&mut self, space: usize) {
121        self.buf.resize(self.filled_len + space, 0);
122    }
123    async fn read<S: AsyncRead + Unpin>(&mut self, stream: &mut S) -> Result<usize, Error> {
124        let n = stream.read(&mut self.buf[self.filled_len..]).await?;
125        self.filled_len += n;
126        Ok(n)
127    }
128    fn is_full(&self) -> bool {
129        self.filled_len >= self.buf.len()
130    }
131}
132
133impl Drop for BufTruncator<'_> {
134    fn drop(&mut self) {
135        self.buf.truncate(self.filled_len);
136    }
137}
138
139async fn read_raw_into<S: AsyncRead + Unpin>(
140    stream: &mut S,
141    buf: &mut BytesMut,
142    cnt: usize,
143) -> Result<(), Error> {
144    let mut buf = BufTruncator::new(buf);
145    buf.reserve(cnt);
146
147    while !buf.is_full() {
148        let n = buf.read(stream).await?;
149
150        if n == 0 {
151            // a zero read when we had space in the read buffer
152            // should be treated as an EOF
153
154            // and an unexpected EOF means the server told us to go away
155
156            return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
157        }
158    }
159
160    Ok(())
161}