// sqlx_build_trust_core/net/socket/buffered.rs
use crate::net::Socket;
2use bytes::BytesMut;
3use std::{cmp, io};
4
5use crate::error::Error;
6
7use crate::io::{Decode, Encode};
8
/// Initial capacity, in bytes, given to both the read and the write buffer.
/// The `shrink` routines also use it as the size they shrink back towards.
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
11
12pub struct BufferedSocket<S> {
13 socket: S,
14 write_buf: WriteBuffer,
15 read_buf: ReadBuffer,
16}
17
/// Outgoing byte buffer tracked by two cursors into `buf`.
///
/// The bytes still waiting to be sent are `buf[bytes_flushed..bytes_written]`;
/// the methods assert `bytes_flushed <= bytes_written <= buf.len()`.
pub struct WriteBuffer {
    /// Backing storage. Its length may extend past `bytes_written`.
    buf: Vec<u8>,
    /// Index one past the last byte that has been written into `buf`.
    bytes_written: usize,
    /// Number of leading bytes of the written region already consumed.
    bytes_flushed: usize,
}
23
24pub struct ReadBuffer {
25 read: BytesMut,
26 available: BytesMut,
27}
28
29impl<S: Socket> BufferedSocket<S> {
30 pub fn new(socket: S) -> Self
31 where
32 S: Sized,
33 {
34 BufferedSocket {
35 socket,
36 write_buf: WriteBuffer {
37 buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
38 bytes_written: 0,
39 bytes_flushed: 0,
40 },
41 read_buf: ReadBuffer {
42 read: BytesMut::new(),
43 available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
44 },
45 }
46 }
47
48 pub async fn read_buffered(&mut self, len: usize) -> io::Result<BytesMut> {
49 self.read_buf.read(len, &mut self.socket).await
50 }
51
52 pub fn write_buffer(&self) -> &WriteBuffer {
53 &self.write_buf
54 }
55
56 pub fn write_buffer_mut(&mut self) -> &mut WriteBuffer {
57 &mut self.write_buf
58 }
59
60 pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result<T, Error>
61 where
62 T: Decode<'de, ()>,
63 {
64 self.read_with(byte_len, ()).await
65 }
66
67 pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result<T, Error>
68 where
69 T: Decode<'de, C>,
70 {
71 T::decode_with(self.read_buffered(byte_len).await?.freeze(), context)
72 }
73
74 pub fn write<'en, T>(&mut self, value: T)
75 where
76 T: Encode<'en, ()>,
77 {
78 self.write_with(value, ())
79 }
80
81 pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
82 where
83 T: Encode<'en, C>,
84 {
85 value.encode_with(self.write_buf.buf_mut(), context);
86 self.write_buf.bytes_written = self.write_buf.buf.len();
87 self.write_buf.sanity_check();
88 }
89
90 pub async fn flush(&mut self) -> io::Result<()> {
91 while !self.write_buf.is_empty() {
92 let written = self.socket.write(self.write_buf.get()).await?;
93 self.write_buf.consume(written);
94 self.write_buf.sanity_check();
95 }
96
97 self.socket.flush().await?;
98
99 Ok(())
100 }
101
102 pub async fn shutdown(&mut self) -> io::Result<()> {
103 self.flush().await?;
104 self.socket.shutdown().await
105 }
106
107 pub fn shrink_buffers(&mut self) {
108 self.write_buf.shrink();
110 self.read_buf.shrink();
111 }
112
113 pub fn into_inner(self) -> S {
114 self.socket
115 }
116
117 pub fn boxed(self) -> BufferedSocket<Box<dyn Socket>> {
118 BufferedSocket {
119 socket: Box::new(self.socket),
120 write_buf: self.write_buf,
121 read_buf: self.read_buf,
122 }
123 }
124}
125
126impl WriteBuffer {
127 fn sanity_check(&self) {
128 assert_ne!(self.buf.capacity(), 0);
129 assert!(self.bytes_written <= self.buf.len());
130 assert!(self.bytes_flushed <= self.bytes_written);
131 }
132
133 pub fn buf_mut(&mut self) -> &mut Vec<u8> {
134 self.buf.truncate(self.bytes_written);
135 self.sanity_check();
136 &mut self.buf
137 }
138
139 pub fn init_remaining_mut(&mut self) -> &mut [u8] {
140 self.buf.resize(self.buf.capacity(), 0);
141 self.sanity_check();
142 &mut self.buf[self.bytes_written..]
143 }
144
145 pub fn put_slice(&mut self, slice: &[u8]) {
146 if let Some(dest) = self.buf[self.bytes_written..].get_mut(..slice.len()) {
149 dest.copy_from_slice(slice);
150 } else {
151 self.buf.truncate(self.bytes_written);
152 self.buf.extend_from_slice(slice);
153 }
154 self.advance(slice.len());
155 self.sanity_check();
156 }
157
158 pub fn advance(&mut self, amt: usize) {
159 let new_bytes_written = self
160 .bytes_written
161 .checked_add(amt)
162 .expect("self.bytes_written + amt overflowed");
163
164 assert!(new_bytes_written <= self.buf.len());
165
166 self.bytes_written = new_bytes_written;
167
168 self.sanity_check();
169 }
170
171 pub fn is_empty(&self) -> bool {
172 self.bytes_flushed >= self.bytes_written
173 }
174
175 pub fn is_full(&self) -> bool {
176 self.bytes_written == self.buf.len()
177 }
178
179 pub fn get(&self) -> &[u8] {
180 &self.buf[self.bytes_flushed..self.bytes_written]
181 }
182
183 pub fn get_mut(&mut self) -> &mut [u8] {
184 &mut self.buf[self.bytes_flushed..self.bytes_written]
185 }
186
187 pub fn shrink(&mut self) {
188 if self.bytes_flushed > 0 {
189 self.buf
192 .copy_within(self.bytes_flushed..self.bytes_written, 0);
193 self.bytes_written -= self.bytes_flushed;
194 self.bytes_flushed = 0
195 }
196
197 self.buf
199 .truncate(cmp::max(self.bytes_written, DEFAULT_BUF_SIZE));
200 self.buf.shrink_to_fit();
201 }
202
203 fn consume(&mut self, amt: usize) {
204 let new_bytes_flushed = self
205 .bytes_flushed
206 .checked_add(amt)
207 .expect("self.bytes_flushed + amt overflowed");
208
209 assert!(new_bytes_flushed <= self.bytes_written);
210
211 self.bytes_flushed = new_bytes_flushed;
212
213 if self.bytes_flushed == self.bytes_written {
214 self.bytes_flushed = 0;
216 self.bytes_written = 0;
217 }
218
219 self.sanity_check();
220 }
221}
222
223impl ReadBuffer {
224 async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<BytesMut> {
225 while self.read.len() < len {
228 self.reserve(len - self.read.len());
229
230 let read = socket.read(&mut self.available).await?;
231
232 if read == 0 {
233 return Err(io::Error::new(
234 io::ErrorKind::UnexpectedEof,
235 format!(
236 "expected to read {} bytes, got {} bytes at EOF",
237 len,
238 self.read.len()
239 ),
240 ));
241 }
242
243 self.advance(read);
244 }
245
246 Ok(self.drain(len))
247 }
248
249 fn reserve(&mut self, amt: usize) {
250 if let Some(additional) = amt.checked_sub(self.available.capacity()) {
251 self.available.reserve(additional);
252 }
253 }
254
255 fn advance(&mut self, amt: usize) {
256 self.read.unsplit(self.available.split_to(amt));
257 }
258
259 fn drain(&mut self, amt: usize) -> BytesMut {
260 self.read.split_to(amt)
261 }
262
263 fn shrink(&mut self) {
264 if self.available.capacity() > DEFAULT_BUF_SIZE {
265 self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
277 }
278 }
279}