sqlx_build_trust_core/net/socket/
buffered.rs

1use crate::net::Socket;
2use bytes::BytesMut;
3use std::{cmp, io};
4
5use crate::error::Error;
6
7use crate::io::{Decode, Encode};
8
/// Default capacity, in bytes, for both the read and write buffers.
///
/// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
const DEFAULT_BUF_SIZE: usize = 8192;
11
/// A [`Socket`] wrapper that buffers both reads and writes.
pub struct BufferedSocket<S> {
    /// The underlying (unbuffered) socket.
    socket: S,
    /// Encoded data waiting to be sent on the next `flush()`.
    write_buf: WriteBuffer,
    /// Data received from the socket but not yet consumed by the caller.
    read_buf: ReadBuffer,
}
17
/// Write-side buffer: a `Vec<u8>` partitioned by two cursors.
///
/// Layout: `buf[..bytes_flushed]` has already been sent to the socket,
/// `buf[bytes_flushed..bytes_written]` is pending flush, and
/// `buf[bytes_written..]` is initialized scratch space that may be reused
/// by later writes (see `put_slice` / `init_remaining_mut`).
pub struct WriteBuffer {
    buf: Vec<u8>,
    /// End of data written into the buffer (i.e. start of the scratch tail).
    bytes_written: usize,
    /// End of data already handed to the socket; always `<= bytes_written`.
    bytes_flushed: usize,
}
23
/// Read-side buffer built from two halves of (ideally) a single `BytesMut`
/// allocation, shuffled back and forth via `split_to`/`unsplit`.
pub struct ReadBuffer {
    /// Bytes received from the socket but not yet returned to the caller.
    read: BytesMut,
    /// Spare capacity the socket reads into; filled bytes are moved into `read`.
    available: BytesMut,
}
28
29impl<S: Socket> BufferedSocket<S> {
30    pub fn new(socket: S) -> Self
31    where
32        S: Sized,
33    {
34        BufferedSocket {
35            socket,
36            write_buf: WriteBuffer {
37                buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
38                bytes_written: 0,
39                bytes_flushed: 0,
40            },
41            read_buf: ReadBuffer {
42                read: BytesMut::new(),
43                available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
44            },
45        }
46    }
47
48    pub async fn read_buffered(&mut self, len: usize) -> io::Result<BytesMut> {
49        self.read_buf.read(len, &mut self.socket).await
50    }
51
52    pub fn write_buffer(&self) -> &WriteBuffer {
53        &self.write_buf
54    }
55
56    pub fn write_buffer_mut(&mut self) -> &mut WriteBuffer {
57        &mut self.write_buf
58    }
59
60    pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result<T, Error>
61    where
62        T: Decode<'de, ()>,
63    {
64        self.read_with(byte_len, ()).await
65    }
66
67    pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result<T, Error>
68    where
69        T: Decode<'de, C>,
70    {
71        T::decode_with(self.read_buffered(byte_len).await?.freeze(), context)
72    }
73
74    pub fn write<'en, T>(&mut self, value: T)
75    where
76        T: Encode<'en, ()>,
77    {
78        self.write_with(value, ())
79    }
80
81    pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
82    where
83        T: Encode<'en, C>,
84    {
85        value.encode_with(self.write_buf.buf_mut(), context);
86        self.write_buf.bytes_written = self.write_buf.buf.len();
87        self.write_buf.sanity_check();
88    }
89
90    pub async fn flush(&mut self) -> io::Result<()> {
91        while !self.write_buf.is_empty() {
92            let written = self.socket.write(self.write_buf.get()).await?;
93            self.write_buf.consume(written);
94            self.write_buf.sanity_check();
95        }
96
97        self.socket.flush().await?;
98
99        Ok(())
100    }
101
102    pub async fn shutdown(&mut self) -> io::Result<()> {
103        self.flush().await?;
104        self.socket.shutdown().await
105    }
106
107    pub fn shrink_buffers(&mut self) {
108        // Won't drop data still in the buffer.
109        self.write_buf.shrink();
110        self.read_buf.shrink();
111    }
112
113    pub fn into_inner(self) -> S {
114        self.socket
115    }
116
117    pub fn boxed(self) -> BufferedSocket<Box<dyn Socket>> {
118        BufferedSocket {
119            socket: Box::new(self.socket),
120            write_buf: self.write_buf,
121            read_buf: self.read_buf,
122        }
123    }
124}
125
126impl WriteBuffer {
127    fn sanity_check(&self) {
128        assert_ne!(self.buf.capacity(), 0);
129        assert!(self.bytes_written <= self.buf.len());
130        assert!(self.bytes_flushed <= self.bytes_written);
131    }
132
133    pub fn buf_mut(&mut self) -> &mut Vec<u8> {
134        self.buf.truncate(self.bytes_written);
135        self.sanity_check();
136        &mut self.buf
137    }
138
139    pub fn init_remaining_mut(&mut self) -> &mut [u8] {
140        self.buf.resize(self.buf.capacity(), 0);
141        self.sanity_check();
142        &mut self.buf[self.bytes_written..]
143    }
144
145    pub fn put_slice(&mut self, slice: &[u8]) {
146        // If we already have an initialized area that can fit the slice,
147        // don't change `self.buf.len()`
148        if let Some(dest) = self.buf[self.bytes_written..].get_mut(..slice.len()) {
149            dest.copy_from_slice(slice);
150        } else {
151            self.buf.truncate(self.bytes_written);
152            self.buf.extend_from_slice(slice);
153        }
154        self.advance(slice.len());
155        self.sanity_check();
156    }
157
158    pub fn advance(&mut self, amt: usize) {
159        let new_bytes_written = self
160            .bytes_written
161            .checked_add(amt)
162            .expect("self.bytes_written + amt overflowed");
163
164        assert!(new_bytes_written <= self.buf.len());
165
166        self.bytes_written = new_bytes_written;
167
168        self.sanity_check();
169    }
170
171    pub fn is_empty(&self) -> bool {
172        self.bytes_flushed >= self.bytes_written
173    }
174
175    pub fn is_full(&self) -> bool {
176        self.bytes_written == self.buf.len()
177    }
178
179    pub fn get(&self) -> &[u8] {
180        &self.buf[self.bytes_flushed..self.bytes_written]
181    }
182
183    pub fn get_mut(&mut self) -> &mut [u8] {
184        &mut self.buf[self.bytes_flushed..self.bytes_written]
185    }
186
187    pub fn shrink(&mut self) {
188        if self.bytes_flushed > 0 {
189            // Move any data that remains to be flushed to the beginning of the buffer,
190            // if necessary.
191            self.buf
192                .copy_within(self.bytes_flushed..self.bytes_written, 0);
193            self.bytes_written -= self.bytes_flushed;
194            self.bytes_flushed = 0
195        }
196
197        // Drop excess capacity.
198        self.buf
199            .truncate(cmp::max(self.bytes_written, DEFAULT_BUF_SIZE));
200        self.buf.shrink_to_fit();
201    }
202
203    fn consume(&mut self, amt: usize) {
204        let new_bytes_flushed = self
205            .bytes_flushed
206            .checked_add(amt)
207            .expect("self.bytes_flushed + amt overflowed");
208
209        assert!(new_bytes_flushed <= self.bytes_written);
210
211        self.bytes_flushed = new_bytes_flushed;
212
213        if self.bytes_flushed == self.bytes_written {
214            // Reset cursors to zero if we've consumed the whole buffer
215            self.bytes_flushed = 0;
216            self.bytes_written = 0;
217        }
218
219        self.sanity_check();
220    }
221}
222
impl ReadBuffer {
    /// Read from `socket` until at least `len` bytes are buffered, then split
    /// off and return exactly `len` bytes.
    ///
    /// Returns `ErrorKind::UnexpectedEof` if the socket reports EOF (a read of
    /// zero bytes) before `len` bytes have been accumulated.
    async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<BytesMut> {
        // Because of how `BytesMut` works, we should only be shifting capacity back and forth
        // between `read` and `available` unless we have to read an oversize message.
        while self.read.len() < len {
            // Make sure `available` can hold the bytes we still need.
            self.reserve(len - self.read.len());

            let read = socket.read(&mut self.available).await?;

            if read == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    format!(
                        "expected to read {} bytes, got {} bytes at EOF",
                        len,
                        self.read.len()
                    ),
                ));
            }

            self.advance(read);
        }

        Ok(self.drain(len))
    }

    /// Ensure `available` has capacity for at least `amt` bytes.
    // NOTE(review): compares `amt` against total capacity, not spare capacity;
    // this assumes `available` is empty here (`advance` moves all filled bytes
    // into `read`) — confirm against `Socket::read`'s fill behavior.
    fn reserve(&mut self, amt: usize) {
        if let Some(additional) = amt.checked_sub(self.available.capacity()) {
            self.available.reserve(additional);
        }
    }

    /// Move `amt` freshly-read bytes from the front of `available` into `read`.
    ///
    /// When both halves come from the same allocation, `unsplit` rejoins them
    /// without copying.
    fn advance(&mut self, amt: usize) {
        self.read.unsplit(self.available.split_to(amt));
    }

    /// Split off and return the first `amt` buffered bytes.
    fn drain(&mut self, amt: usize) -> BytesMut {
        self.read.split_to(amt)
    }

    /// Release oversize capacity by replacing `available` with a
    /// default-sized buffer.
    fn shrink(&mut self) {
        if self.available.capacity() > DEFAULT_BUF_SIZE {
            // `BytesMut` doesn't have a way to shrink its capacity,
            // but we only use `available` for spare capacity anyway so we can just replace it.
            //
            // If `self.read` still contains data on the next call to `advance` then this might
            // force a memcpy as they'll no longer be pointing to the same allocation,
            // but that's kind of unavoidable.
            //
            // The `async-std` impl of `Socket` will also need to re-zero the buffer,
            // but that's also kind of unavoidable.
            //
            // We should be warning the user not to call this often.
            self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
        }
    }
}
279}