1use std::{
4 future::Future,
5 io::{self, IoSlice},
6 pin::Pin,
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use bytes::Bytes;
12use futures::{FutureExt, channel::oneshot};
13use tokio::{
14 io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf},
15 time::Sleep,
16};
17
/// Configures and creates a [`Connection`].
///
/// Obtained from [`Connection::builder`]; both timeouts start out as `None`,
/// which disables them.
#[derive(Debug, Clone, Copy, Default)]
#[must_use = "a `ConnectionBuilder` does nothing until `build` is called"]
pub struct ConnectionBuilder {
    /// How long a read may keep returning `Pending` (with no `Ready` in
    /// between) before it fails with `ErrorKind::TimedOut`; `None` disables it.
    read_timeout: Option<Duration>,
    /// Same as `read_timeout`, but for writes, flushes and shutdowns.
    write_timeout: Option<Duration>,
}
23
24impl ConnectionBuilder {
25 #[inline]
27 const fn new() -> Self {
28 Self {
29 read_timeout: None,
30 write_timeout: None,
31 }
32 }
33
34 #[inline]
36 pub const fn read_timeout(mut self, timeout: Option<Duration>) -> Self {
37 self.read_timeout = timeout;
38 self
39 }
40
41 #[inline]
43 pub const fn write_timeout(mut self, timeout: Option<Duration>) -> Self {
44 self.write_timeout = timeout;
45 self
46 }
47
48 pub fn build<IO>(self, io: IO) -> Connection<IO> {
50 let context = ConnectionContext::new(self.read_timeout, self.write_timeout);
51
52 Connection {
53 inner: io,
54 buffer: PrependBuffer::new(),
55 context: Box::pin(context),
56 }
57 }
58}
59
pin_project_lite::pin_project! {
    /// An I/O stream wrapper that adds optional read/write timeouts and lets
    /// callers push bytes back in front of the stream's data.
    ///
    /// Created through [`Connection::builder`].
    pub struct Connection<IO> {
        // The wrapped stream. Structurally pinned, so `IO` need not be `Unpin`.
        #[pin]
        inner: IO,
        // Bytes handed out by `poll_read` before anything is read from `inner`.
        buffer: PrependBuffer,
        // Timeout configuration and running timers. Keeping it behind
        // `Pin<Box<_>>` makes this field `Unpin` even though the `Sleep`
        // timers inside it are not.
        context: Pin<Box<ConnectionContext>>,
    }
}
69
70impl Connection<()> {
71 #[inline]
73 pub const fn builder() -> ConnectionBuilder {
74 ConnectionBuilder::new()
75 }
76}
77
78impl<IO> Connection<IO> {
79 #[inline]
81 pub fn prepend(mut self, item: Bytes) -> Self {
82 self.buffer.prepend(item);
83 self
84 }
85}
86
87impl<IO> Connection<IO>
88where
89 IO: AsyncRead + AsyncWrite,
90{
91 pub fn split(mut self) -> (ConnectionReader<IO>, ConnectionWriter<IO>) {
93 let buffer = self.buffer.take();
94
95 let (r, w) = tokio::io::split(self);
96
97 let reader = ConnectionReader { inner: r, buffer };
98
99 let writer = ConnectionWriter { inner: w };
100
101 (reader, writer)
102 }
103}
104
105impl<IO> Connection<IO>
106where
107 IO: AsyncRead + AsyncWrite + Send + 'static,
108{
109 pub fn upgrade(self) -> Upgraded {
113 Upgraded {
114 inner: Box::pin(self.inner),
115 buffer: self.buffer,
116 }
117 }
118}
119
120impl<IO> AsyncRead for Connection<IO>
121where
122 IO: AsyncRead,
123{
124 fn poll_read(
125 self: Pin<&mut Self>,
126 cx: &mut Context<'_>,
127 buf: &mut ReadBuf<'_>,
128 ) -> Poll<io::Result<()>> {
129 let this = self.project();
130
131 if !this.buffer.is_empty() {
132 this.buffer.read(buf);
134
135 return Poll::Ready(Ok(()));
136 }
137
138 let res = this.inner.poll_read(cx, buf);
139
140 if res.is_ready() {
141 this.context.as_mut().reset_read_timeout();
142 } else {
143 this.context.as_mut().check_read_timeout(cx)?;
144 }
145
146 res
147 }
148}
149
150impl<IO> AsyncWrite for Connection<IO>
151where
152 IO: AsyncWrite,
153{
154 fn poll_write(
155 self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 buf: &[u8],
158 ) -> Poll<io::Result<usize>> {
159 let this = self.project();
160
161 let res = this.inner.poll_write(cx, buf);
162
163 if res.is_ready() {
164 this.context.as_mut().reset_write_timeout();
165 } else {
166 this.context.as_mut().check_write_timeout(cx)?;
167 }
168
169 res
170 }
171
172 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173 let this = self.project();
174
175 let res = this.inner.poll_flush(cx);
176
177 if res.is_ready() {
178 this.context.as_mut().reset_write_timeout();
179 } else {
180 this.context.as_mut().check_write_timeout(cx)?;
181 }
182
183 res
184 }
185
186 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
187 let this = self.project();
188
189 let res = this.inner.poll_shutdown(cx);
190
191 if res.is_ready() {
192 this.context.as_mut().reset_write_timeout();
193 } else {
194 this.context.as_mut().check_write_timeout(cx)?;
195 }
196
197 res
198 }
199
200 fn poll_write_vectored(
201 self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 bufs: &[IoSlice<'_>],
204 ) -> Poll<io::Result<usize>> {
205 let this = self.project();
206
207 let res = this.inner.poll_write_vectored(cx, bufs);
208
209 if res.is_ready() {
210 this.context.as_mut().reset_write_timeout();
211 } else {
212 this.context.as_mut().check_write_timeout(cx)?;
213 }
214
215 res
216 }
217
218 #[inline]
219 fn is_write_vectored(&self) -> bool {
220 self.inner.is_write_vectored()
221 }
222}
223
pin_project_lite::pin_project! {
    /// Timeout limits and running timers for one [`Connection`].
    struct ConnectionContext {
        // Configured limits; `None` disables the corresponding timeout.
        read_timeout: Option<Duration>,
        write_timeout: Option<Duration>,
        // Timer created the first time a read is seen pending, and cleared as
        // soon as a read completes. Pinned because it holds a `Sleep`.
        #[pin]
        read_timeout_delay: Option<Sleep>,
        // Same as above, for writes, flushes and shutdowns.
        #[pin]
        write_timeout_delay: Option<Sleep>,
    }
}
235
236impl ConnectionContext {
237 #[inline]
239 const fn new(read_timeout: Option<Duration>, write_timeout: Option<Duration>) -> Self {
240 Self {
241 read_timeout,
242 write_timeout,
243 read_timeout_delay: None,
244 write_timeout_delay: None,
245 }
246 }
247
248 #[inline]
250 fn reset_read_timeout(self: Pin<&mut Self>) {
251 let mut this = self.project();
252
253 this.read_timeout_delay.set(None);
254 }
255
256 fn check_read_timeout(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
258 let mut this = self.project();
259
260 if let Some(timeout) = *this.read_timeout {
261 if this.read_timeout_delay.is_none() {
262 this.read_timeout_delay
263 .set(Some(tokio::time::sleep(timeout)));
264 }
265
266 if let Some(timeout) = this.read_timeout_delay.as_pin_mut() {
267 if timeout.poll(cx).is_ready() {
268 return Err(io::Error::new(io::ErrorKind::TimedOut, "read timeout"));
269 }
270 }
271 }
272
273 Ok(())
274 }
275
276 #[inline]
278 fn reset_write_timeout(self: Pin<&mut Self>) {
279 let mut this = self.project();
280
281 this.write_timeout_delay.set(None);
282 }
283
284 fn check_write_timeout(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
286 let mut this = self.project();
287
288 if let Some(timeout) = *this.write_timeout {
289 if this.write_timeout_delay.is_none() {
290 this.write_timeout_delay
291 .set(Some(tokio::time::sleep(timeout)));
292 }
293
294 if let Some(timeout) = this.write_timeout_delay.as_pin_mut() {
295 if timeout.poll(cx).is_ready() {
296 return Err(io::Error::new(io::ErrorKind::TimedOut, "write timeout"));
297 }
298 }
299 }
300
301 Ok(())
302 }
303}
304
/// Read half of a [`Connection`], produced by [`Connection::split`].
///
/// Owns the bytes that were prepended at split time and returns them before
/// reading from the shared connection.
pub struct ConnectionReader<IO> {
    /// Shared read half of the underlying connection.
    inner: ReadHalf<Connection<IO>>,
    /// Bytes returned before anything is read from `inner`.
    buffer: PrependBuffer,
}
310
311impl<IO> ConnectionReader<IO> {
312 #[inline]
314 pub fn prepend(mut self, item: Bytes) -> Self {
315 self.buffer.prepend(item);
316 self
317 }
318}
319
320impl<IO> ConnectionReader<IO>
321where
322 IO: Unpin,
323{
324 pub fn join(self, writer: ConnectionWriter<IO>) -> Connection<IO> {
326 let mut connection = self.inner.unsplit(writer.inner);
327
328 connection.buffer = self.buffer;
329 connection
330 }
331}
332
333impl<IO> AsyncRead for ConnectionReader<IO>
334where
335 IO: AsyncRead,
336{
337 fn poll_read(
338 mut self: Pin<&mut Self>,
339 cx: &mut Context<'_>,
340 buf: &mut ReadBuf<'_>,
341 ) -> Poll<io::Result<()>> {
342 if !self.buffer.is_empty() {
343 self.buffer.read(buf);
345
346 return Poll::Ready(Ok(()));
347 }
348
349 let pinned = Pin::new(&mut self.inner);
350
351 pinned.poll_read(cx, buf)
352 }
353}
354
/// Write half of a [`Connection`], produced by [`Connection::split`].
pub struct ConnectionWriter<IO> {
    /// Shared write half of the underlying connection.
    inner: WriteHalf<Connection<IO>>,
}
359
360impl<IO> AsyncWrite for ConnectionWriter<IO>
361where
362 IO: AsyncWrite,
363{
364 #[inline]
365 fn poll_write(
366 mut self: Pin<&mut Self>,
367 cx: &mut Context<'_>,
368 buf: &[u8],
369 ) -> Poll<io::Result<usize>> {
370 AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf)
371 }
372
373 #[inline]
374 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
375 AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx)
376 }
377
378 #[inline]
379 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
380 AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx)
381 }
382
383 #[inline]
384 fn poll_write_vectored(
385 mut self: Pin<&mut Self>,
386 cx: &mut Context<'_>,
387 bufs: &[IoSlice<'_>],
388 ) -> Poll<io::Result<usize>> {
389 AsyncWrite::poll_write_vectored(Pin::new(&mut self.inner), cx, bufs)
390 }
391
392 #[inline]
393 fn is_write_vectored(&self) -> bool {
394 self.inner.is_write_vectored()
395 }
396}
397
398trait AsyncReadWrite: AsyncRead + AsyncWrite {}
400
401impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite {}
402
/// A type-erased, owned I/O stream produced by [`Connection::upgrade`].
///
/// Unlike [`Connection`], it applies no read or write timeouts. Bytes that were
/// still prepended at upgrade time are returned by the first reads.
pub struct Upgraded {
    /// The original I/O object, boxed and pinned behind the object-safe
    /// `AsyncReadWrite` trait so its concrete type is erased.
    inner: Pin<Box<dyn AsyncReadWrite + Send>>,
    /// Bytes returned before anything is read from `inner`.
    buffer: PrependBuffer,
}
408
409impl AsyncRead for Upgraded {
410 fn poll_read(
411 mut self: Pin<&mut Self>,
412 cx: &mut Context<'_>,
413 buf: &mut ReadBuf<'_>,
414 ) -> Poll<io::Result<()>> {
415 if !self.buffer.is_empty() {
416 self.buffer.read(buf);
418
419 return Poll::Ready(Ok(()));
420 }
421
422 let pinned = Pin::new(&mut self.inner);
423
424 pinned.poll_read(cx, buf)
425 }
426}
427
428impl AsyncWrite for Upgraded {
429 #[inline]
430 fn poll_write(
431 mut self: Pin<&mut Self>,
432 cx: &mut Context<'_>,
433 buf: &[u8],
434 ) -> Poll<io::Result<usize>> {
435 AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf)
436 }
437
438 #[inline]
439 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
440 AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx)
441 }
442
443 #[inline]
444 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
445 AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx)
446 }
447
448 #[inline]
449 fn poll_write_vectored(
450 mut self: Pin<&mut Self>,
451 cx: &mut Context<'_>,
452 bufs: &[IoSlice<'_>],
453 ) -> Poll<io::Result<usize>> {
454 AsyncWrite::poll_write_vectored(Pin::new(&mut self.inner), cx, bufs)
455 }
456
457 #[inline]
458 fn is_write_vectored(&self) -> bool {
459 self.inner.is_write_vectored()
460 }
461}
462
463pub struct UpgradeFuture {
465 inner: oneshot::Receiver<Upgraded>,
466}
467
468impl UpgradeFuture {
469 pub fn new() -> (Self, UpgradeRequest) {
471 let (tx, rx) = oneshot::channel();
472
473 let tx = UpgradeRequest { inner: tx };
474 let rx = Self { inner: rx };
475
476 (rx, tx)
477 }
478}
479
480impl Future for UpgradeFuture {
481 type Output = io::Result<Upgraded>;
482
483 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
484 self.inner
485 .poll_unpin(cx)
486 .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
487 }
488}
489
/// Sending side paired with an [`UpgradeFuture`]. It hands the upgraded stream
/// to whoever is awaiting that future.
pub struct UpgradeRequest {
    /// One-shot sender whose receiver lives in the paired `UpgradeFuture`.
    inner: oneshot::Sender<Upgraded>,
}
494
495impl UpgradeRequest {
496 pub fn resolve(self, connection: Upgraded) {
498 let _ = self.inner.send(connection);
499 }
500}
501
/// Stack of byte chunks that must be read before the underlying stream.
///
/// The front of the logical byte sequence is the *last* element of `inner`.
/// Prepending is therefore a `Vec::push`, and finishing the front chunk is a
/// `pop`.
struct PrependBuffer {
    /// Non-empty chunks, front-most last.
    inner: Vec<Bytes>,
}
506
507impl PrependBuffer {
508 #[inline]
510 const fn new() -> Self {
511 Self { inner: Vec::new() }
512 }
513
514 fn prepend(&mut self, item: Bytes) {
516 if !item.is_empty() {
517 self.inner.push(item);
518 }
519 }
520
521 fn read(&mut self, buf: &mut ReadBuf<'_>) {
523 if let Some(chunk) = self.inner.last_mut() {
524 let available = chunk.len();
525
526 let take = available.min(buf.remaining());
527
528 buf.put_slice(&chunk.split_to(take));
529
530 if chunk.is_empty() {
531 self.inner.pop();
532 }
533 }
534 }
535
536 #[inline]
538 fn take(&mut self) -> Self {
539 Self {
540 inner: std::mem::take(&mut self.inner),
541 }
542 }
543
544 #[inline]
546 fn is_empty(&self) -> bool {
547 self.inner.is_empty()
548 }
549}