1use pin_project::pin_project;
2use std::any::type_name;
3use std::io::Error;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
7use tracing::{info_span, Span};
8use tracing_core::field::debug;
9
/// Span field name for the running byte total (`total_read` / `total_write`).
pub const VALUE: &str = "value";
/// Span field name for the byte count of the most recent poll; only recorded when that
/// poll completed with `Ok`.
pub const INCREMENTAL: &str = "incremental";
/// Span field name for the length of the buffer a caller passed to `poll_write`.
pub const WANT_WRITE: &str = "want_write";
/// Span field name for the `Debug` rendering of the most recent completed poll result.
pub const POLL_RESULT: &str = "poll_result";
14
15#[pin_project]
16pub struct TLInstrumentedAsyncRead<T> {
17 #[pin]
18 inner: T,
19 span: Span,
20 total_read: u64,
21}
22impl<T> AsyncWrite for TLInstrumentedAsyncRead<T>
23where
24 T: AsyncWrite,
25{
26 fn poll_write(
27 self: Pin<&mut Self>,
28 cx: &mut Context<'_>,
29 buf: &[u8],
30 ) -> Poll<Result<usize, Error>> {
31 self.project().inner.poll_write(cx, buf)
32 }
33
34 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
35 self.project().inner.poll_flush(cx)
36 }
37
38 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
39 self.project().inner.poll_shutdown(cx)
40 }
41}
42
43impl<T> AsyncBufRead for TLInstrumentedAsyncRead<T>
44where
45 T: AsyncBufRead,
46{
47 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
48 self.project().inner.poll_fill_buf(cx)
49 }
50
51 fn consume(self: Pin<&mut Self>, amt: usize) {
52 self.project().inner.consume(amt)
53 }
54}
55
56impl<T> AsyncRead for TLInstrumentedAsyncRead<T>
57where
58 T: AsyncRead,
59{
60 #[track_caller]
61 fn poll_read(
62 self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 buf: &mut ReadBuf<'_>,
65 ) -> Poll<std::io::Result<()>> {
66 let me = self.project();
67 let entered = me.span.enter();
68 let prev_len = buf.filled().len();
69 let r = me.inner.poll_read(cx, buf);
70 if r.is_pending() {
71 return r;
72 }
73 let filled_len = buf.filled().len() - prev_len;
74 *me.total_read = *me.total_read + filled_len as u64;
75
76 let identifier = me.span.metadata().unwrap().callsite();
77 let value_field = me.span.field(VALUE).unwrap();
78 let incremental_field = me.span.field(INCREMENTAL).unwrap();
79 let poll_result_field = me.span.field(POLL_RESULT).unwrap();
80 me.span.record_all(
81 &tracing_core::field::FieldSet::new(&[VALUE, INCREMENTAL, POLL_RESULT], identifier)
82 .value_set(&[
83 (
84 &value_field,
85 Some(&*me.total_read as &(dyn tracing_core::field::Value)),
86 ),
87 (
88 &incremental_field,
89 Some(&match &r {
90 Poll::Ready(Ok(_)) => Some(filled_len),
91 _ => None,
92 } as &(dyn tracing_core::field::Value)),
93 ),
94 (
95 &poll_result_field,
96 Some(&debug(&r) as &(dyn tracing_core::field::Value)),
97 ),
98 ]),
99 );
100 drop(entered);
101 r
102 }
103}
104
105pub trait TLAsyncReadExt: Sized {
106 fn instrument_read(
107 self,
108 name: &'static str,
109 total_size: Option<u64>,
110 ) -> TLInstrumentedAsyncRead<Self>;
111}
112
113impl<T> TLAsyncReadExt for T
114where
115 T: AsyncRead,
116{
117 fn instrument_read(
118 self,
119 name: &'static str,
120 total_size: Option<u64>,
121 ) -> TLInstrumentedAsyncRead<Self> {
122 TLInstrumentedAsyncRead {
123 inner: self,
124 span: info_span!(
125 "[t:AsyncRead]",
126 name,
127 "type" = type_name::<T>(),
128 value = 0,
129 incremental = 0,
130 poll_result = "",
131 total_size
132 ),
133 total_read: 0,
134 }
135 }
136}
137
138#[pin_project]
139pub struct TLAsyncWrite<T> {
140 #[pin]
141 inner: T,
142 span: Span,
143 total_write: u64,
144}
145impl<T> AsyncRead for TLAsyncWrite<T>
146where
147 T: AsyncRead,
148{
149 fn poll_read(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 buf: &mut ReadBuf<'_>,
153 ) -> Poll<std::io::Result<()>> {
154 self.project().inner.poll_read(cx, buf)
155 }
156}
157impl<T> AsyncBufRead for TLAsyncWrite<T>
158where
159 T: AsyncBufRead,
160{
161 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
162 self.project().inner.poll_fill_buf(cx)
163 }
164
165 fn consume(self: Pin<&mut Self>, amt: usize) {
166 self.project().inner.consume(amt)
167 }
168}
169
170impl<T> AsyncWrite for TLAsyncWrite<T>
171where
172 T: AsyncWrite,
173{
174 fn poll_write(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: &[u8],
178 ) -> Poll<Result<usize, Error>> {
179 let me = self.project();
180 let entered = me.span.enter();
181 let r = me.inner.poll_write(cx, buf);
182 if r.is_pending() {
183 return r;
184 }
185 if let Poll::Ready(Ok(count)) = &r {
186 *me.total_write = *me.total_write + *count as u64;
187 }
188 let identifier = me.span.metadata().unwrap().callsite();
189 let value_field = me.span.field(VALUE).unwrap();
190 let incremental_field = me.span.field(INCREMENTAL).unwrap();
191 let poll_result_field = me.span.field(POLL_RESULT).unwrap();
192 let want_write_field = me.span.field(WANT_WRITE).unwrap();
193 me.span.record_all(
194 &tracing_core::field::FieldSet::new(
195 &[VALUE, INCREMENTAL, POLL_RESULT, WANT_WRITE],
196 identifier,
197 )
198 .value_set(&[
199 (
200 &value_field,
201 Some(me.total_write as &(dyn tracing_core::field::Value)),
202 ),
203 (
204 &incremental_field,
205 Some(&match &r {
206 Poll::Ready(Ok(n)) => Some(*n),
207 _ => None,
208 } as &(dyn tracing_core::field::Value)),
209 ),
210 (
211 &poll_result_field,
212 Some(&debug(&r) as &(dyn tracing_core::field::Value)),
213 ),
214 (
215 &want_write_field,
216 Some(&buf.len() as &(dyn tracing_core::field::Value)),
217 ),
218 ]),
219 );
220 drop(entered);
221 r
222 }
223
224 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
225 let me = self.project();
226 let _entered = me.span.enter();
227 me.inner.poll_flush(cx)
228 }
229
230 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
231 let me = self.project();
232 let _entered = me.span.enter();
233 me.inner.poll_shutdown(cx)
234 }
235}
236
237pub trait TLAsyncWriteExt: Sized {
238 fn instrument_write(self, name: &'static str, total_size: Option<u64>) -> TLAsyncWrite<Self>;
239}
240
241impl<T> TLAsyncWriteExt for T
242where
243 T: AsyncWrite,
244{
245 fn instrument_write(self, name: &'static str, total_size: Option<u64>) -> TLAsyncWrite<Self> {
246 TLAsyncWrite {
247 inner: self,
248 span: info_span!(
249 "[t:AsyncWrite]",
250 name,
251 "type" = type_name::<T>(),
252 value = 0,
253 incremental = 0,
254 want_write = "",
255 poll_result = "",
256 total_size
257 ),
258 total_write: 0,
259 }
260 }
261}