// tracing_lv/tokio.rs

1use pin_project::pin_project;
2use std::any::type_name;
3use std::io::Error;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
7use tracing::{info_span, Span};
8use tracing_core::field::debug;
9
/// Span field holding the cumulative number of bytes transferred so far.
pub const VALUE: &str = "value";
/// Span field holding the number of bytes transferred by the most recent
/// completed poll (`None`, i.e. left unrecorded, when that poll errored).
pub const INCREMENTAL: &str = "incremental";
/// Span field holding the buffer length passed to the most recent `poll_write`.
pub const WANT_WRITE: &str = "want_write";
/// Span field holding the `Debug` rendering of the most recent poll result.
pub const POLL_RESULT: &str = "poll_result";
14
15#[pin_project]
16pub struct TLInstrumentedAsyncRead<T> {
17    #[pin]
18    inner: T,
19    span: Span,
20    total_read: u64,
21}
22impl<T> AsyncWrite for TLInstrumentedAsyncRead<T>
23where
24    T: AsyncWrite,
25{
26    fn poll_write(
27        self: Pin<&mut Self>,
28        cx: &mut Context<'_>,
29        buf: &[u8],
30    ) -> Poll<Result<usize, Error>> {
31        self.project().inner.poll_write(cx, buf)
32    }
33
34    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
35        self.project().inner.poll_flush(cx)
36    }
37
38    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
39        self.project().inner.poll_shutdown(cx)
40    }
41}
42
43impl<T> AsyncBufRead for TLInstrumentedAsyncRead<T>
44where
45    T: AsyncBufRead,
46{
47    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
48        self.project().inner.poll_fill_buf(cx)
49    }
50
51    fn consume(self: Pin<&mut Self>, amt: usize) {
52        self.project().inner.consume(amt)
53    }
54}
55
56impl<T> AsyncRead for TLInstrumentedAsyncRead<T>
57where
58    T: AsyncRead,
59{
60    #[track_caller]
61    fn poll_read(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &mut ReadBuf<'_>,
65    ) -> Poll<std::io::Result<()>> {
66        let me = self.project();
67        let entered = me.span.enter();
68        let prev_len = buf.filled().len();
69        let r = me.inner.poll_read(cx, buf);
70        if r.is_pending() {
71            return r;
72        }
73        let filled_len = buf.filled().len() - prev_len;
74        *me.total_read = *me.total_read + filled_len as u64;
75
76        let identifier = me.span.metadata().unwrap().callsite();
77        let value_field = me.span.field(VALUE).unwrap();
78        let incremental_field = me.span.field(INCREMENTAL).unwrap();
79        let poll_result_field = me.span.field(POLL_RESULT).unwrap();
80        me.span.record_all(
81            &tracing_core::field::FieldSet::new(&[VALUE, INCREMENTAL, POLL_RESULT], identifier)
82                .value_set(&[
83                    (
84                        &value_field,
85                        Some(&*me.total_read as &(dyn tracing_core::field::Value)),
86                    ),
87                    (
88                        &incremental_field,
89                        Some(&match &r {
90                            Poll::Ready(Ok(_)) => Some(filled_len),
91                            _ => None,
92                        } as &(dyn tracing_core::field::Value)),
93                    ),
94                    (
95                        &poll_result_field,
96                        Some(&debug(&r) as &(dyn tracing_core::field::Value)),
97                    ),
98                ]),
99        );
100        drop(entered);
101        r
102    }
103}
104
105pub trait TLAsyncReadExt: Sized {
106    fn instrument_read(
107        self,
108        name: &'static str,
109        total_size: Option<u64>,
110    ) -> TLInstrumentedAsyncRead<Self>;
111}
112
113impl<T> TLAsyncReadExt for T
114where
115    T: AsyncRead,
116{
117    fn instrument_read(
118        self,
119        name: &'static str,
120        total_size: Option<u64>,
121    ) -> TLInstrumentedAsyncRead<Self> {
122        TLInstrumentedAsyncRead {
123            inner: self,
124            span: info_span!(
125                "[t:AsyncRead]",
126                name,
127                "type" = type_name::<T>(),
128                value = 0,
129                incremental = 0,
130                poll_result = "",
131                total_size
132            ),
133            total_read: 0,
134        }
135    }
136}
137
138#[pin_project]
139pub struct TLAsyncWrite<T> {
140    #[pin]
141    inner: T,
142    span: Span,
143    total_write: u64,
144}
145impl<T> AsyncRead for TLAsyncWrite<T>
146where
147    T: AsyncRead,
148{
149    fn poll_read(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &mut ReadBuf<'_>,
153    ) -> Poll<std::io::Result<()>> {
154        self.project().inner.poll_read(cx, buf)
155    }
156}
157impl<T> AsyncBufRead for TLAsyncWrite<T>
158where
159    T: AsyncBufRead,
160{
161    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
162        self.project().inner.poll_fill_buf(cx)
163    }
164
165    fn consume(self: Pin<&mut Self>, amt: usize) {
166        self.project().inner.consume(amt)
167    }
168}
169
170impl<T> AsyncWrite for TLAsyncWrite<T>
171where
172    T: AsyncWrite,
173{
174    fn poll_write(
175        self: Pin<&mut Self>,
176        cx: &mut Context<'_>,
177        buf: &[u8],
178    ) -> Poll<Result<usize, Error>> {
179        let me = self.project();
180        let entered = me.span.enter();
181        let r = me.inner.poll_write(cx, buf);
182        if r.is_pending() {
183            return r;
184        }
185        if let Poll::Ready(Ok(count)) = &r {
186            *me.total_write = *me.total_write + *count as u64;
187        }
188        let identifier = me.span.metadata().unwrap().callsite();
189        let value_field = me.span.field(VALUE).unwrap();
190        let incremental_field = me.span.field(INCREMENTAL).unwrap();
191        let poll_result_field = me.span.field(POLL_RESULT).unwrap();
192        let want_write_field = me.span.field(WANT_WRITE).unwrap();
193        me.span.record_all(
194            &tracing_core::field::FieldSet::new(
195                &[VALUE, INCREMENTAL, POLL_RESULT, WANT_WRITE],
196                identifier,
197            )
198            .value_set(&[
199                (
200                    &value_field,
201                    Some(me.total_write as &(dyn tracing_core::field::Value)),
202                ),
203                (
204                    &incremental_field,
205                    Some(&match &r {
206                        Poll::Ready(Ok(n)) => Some(*n),
207                        _ => None,
208                    } as &(dyn tracing_core::field::Value)),
209                ),
210                (
211                    &poll_result_field,
212                    Some(&debug(&r) as &(dyn tracing_core::field::Value)),
213                ),
214                (
215                    &want_write_field,
216                    Some(&buf.len() as &(dyn tracing_core::field::Value)),
217                ),
218            ]),
219        );
220        drop(entered);
221        r
222    }
223
224    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
225        let me = self.project();
226        let _entered = me.span.enter();
227        me.inner.poll_flush(cx)
228    }
229
230    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
231        let me = self.project();
232        let _entered = me.span.enter();
233        me.inner.poll_shutdown(cx)
234    }
235}
236
237pub trait TLAsyncWriteExt: Sized {
238    fn instrument_write(self, name: &'static str, total_size: Option<u64>) -> TLAsyncWrite<Self>;
239}
240
241impl<T> TLAsyncWriteExt for T
242where
243    T: AsyncWrite,
244{
245    fn instrument_write(self, name: &'static str, total_size: Option<u64>) -> TLAsyncWrite<Self> {
246        TLAsyncWrite {
247            inner: self,
248            span: info_span!(
249                "[t:AsyncWrite]",
250                name,
251                "type" = type_name::<T>(),
252                value = 0,
253                incremental = 0,
254                want_write = "",
255                poll_result = "",
256                total_size
257            ),
258            total_write: 0,
259        }
260    }
261}