1#![cfg_attr(not(feature = "std"), no_std)]
2extern crate alloc;
28use alloc::sync::Arc;
29use core::sync::atomic::{AtomicU64, Ordering};
30
31#[cfg(any(feature = "futures", feature = "tokio"))]
32use core::{pin::Pin, task::Poll};
33
34#[cfg_attr(any(feature = "futures", feature = "tokio"), pin_project::pin_project)]
35#[derive(Debug, Clone)]
36pub struct WriteMonitor<W> {
37 #[cfg_attr(any(feature = "futures", feature = "tokio"), pin)]
38 inner: W,
39 bytes_written: Arc<AtomicU64>,
40}
41
42impl<W> WriteMonitor<W> {
43 pub fn new(inner: W) -> Self {
44 Self {
45 inner,
46 bytes_written: Arc::new(AtomicU64::new(0)),
47 }
48 }
49
50 pub fn bytes_written(&self) -> u64 {
51 self.bytes_written.load(Ordering::Acquire)
52 }
53
54 pub fn monitor(&self) -> Monitor {
56 Monitor {
57 bytes_written: self.bytes_written.clone(),
58 }
59 }
60
61 pub fn into_inner(self) -> W {
62 self.inner
63 }
64}
65
/// A handle onto the byte counter of a [`WriteMonitor`].
///
/// Created by [`WriteMonitor::monitor`]. Every clone shares the same underlying counter.
#[derive(Clone, Debug)]
pub struct Monitor {
    bytes_written: Arc<AtomicU64>,
}

impl Monitor {
    /// Reads the current number of bytes written through the associated writer.
    pub fn bytes_written(&self) -> u64 {
        let Self { bytes_written } = self;
        bytes_written.load(Ordering::Acquire)
    }

    /// Consumes the handle and returns the shared counter itself.
    pub fn into_inner(self) -> Arc<AtomicU64> {
        let Monitor { bytes_written } = self;
        bytes_written
    }
}
80
/// Forwards every operation to the inner tokio writer. Each successful write
/// adds the number of bytes the inner writer accepted to the shared counter.
///
/// `pin_project` projects `inner` to `Pin<&mut W>`, so `W` does not need to be
/// `Unpin`.
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl<W: tokio::io::AsyncWrite> tokio::io::AsyncWrite for WriteMonitor<W> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<std::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write(cx, buf);
        if let Poll::Ready(Ok(n)) = result {
            // usize -> u64 is lossless on every platform Rust supports.
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    // Forwarded explicitly. Otherwise the trait default would route vectored
    // writes through `poll_write` and hide the inner writer's vectored support.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> core::task::Poll<std::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(n)) = result {
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<std::io::Result<()>> {
        self.project().inner.poll_shutdown(cx)
    }
}
113
/// Forwards every operation to the inner `futures` writer. Each successful
/// write adds the number of bytes the inner writer accepted to the shared counter.
///
/// `pin_project` projects `inner` to `Pin<&mut W>`, so `W` does not need to be
/// `Unpin`.
#[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
#[cfg(feature = "futures")]
impl<W: futures::io::AsyncWrite> futures::io::AsyncWrite for WriteMonitor<W> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> core::task::Poll<futures::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write(cx, buf);
        if let Poll::Ready(Ok(n)) = result {
            // usize -> u64 is lossless on every platform Rust supports.
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    // Forwarded explicitly. Otherwise the trait default would route vectored
    // writes through `poll_write` and hide the inner writer's vectored support.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        bufs: &[futures::io::IoSlice<'_>],
    ) -> core::task::Poll<futures::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(n)) = result {
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<futures::io::Result<()>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<futures::io::Result<()>> {
        self.project().inner.poll_close(cx)
    }
}
143
/// Forwards every operation to the inner blocking writer. Each successful
/// write adds the number of bytes the inner writer accepted to the shared counter.
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
#[cfg(feature = "std")]
impl<W: std::io::Write> std::io::Write for WriteMonitor<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let result = std::io::Write::write(&mut self.inner, buf);
        if let Ok(n) = result {
            // usize -> u64 is lossless on every platform Rust supports.
            self.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    // Forwarded explicitly. The trait default writes only the first non-empty
    // buffer, which would hide the inner writer's own vectored I/O.
    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
        let result = std::io::Write::write_vectored(&mut self.inner, bufs);
        if let Ok(n) = result {
            self.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    fn flush(&mut self) -> std::io::Result<()> {
        std::io::Write::flush(&mut self.inner)
    }
}
157
/// Writing through a `WriteMonitor` must:
/// - update both its own counter and any `Monitor` taken from it;
/// - deliver the data to the inner writer;
/// - keep the `Monitor` usable after `into_inner`.
///
/// Gated on `std` because it relies on the `std::io::Write` impl, which only
/// exists with that feature.
#[cfg(feature = "std")]
#[test]
fn test_write_monitor() {
    use std::io::Write;

    let mut buf = Vec::new();
    let mut wm = WriteMonitor::new(&mut buf);
    let monitor = wm.monitor();

    let data = b"Hello World";
    wm.write_all(data).unwrap();
    assert_eq!(monitor.bytes_written(), data.len() as u64);
    assert_eq!(wm.bytes_written(), data.len() as u64);

    // A vectored write may accept only part of its input; the counter must
    // match exactly what the inner writer reported.
    let extra = [std::io::IoSlice::new(b", "), std::io::IoSlice::new(b"again")];
    let n = wm.write_vectored(&extra).unwrap();
    let total = (data.len() + n) as u64;
    assert_eq!(monitor.bytes_written(), total);

    let inner = wm.into_inner();
    assert!(inner.starts_with(data));
    assert_eq!(inner.len() as u64, total);
    // The monitor outlives the wrapper it was taken from.
    assert_eq!(monitor.bytes_written(), total);
}