write_monitor/
lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
//! `WriteMonitor` wraps a writer and counts how many bytes are written through it.
//! This is useful for reporting the progress of long-running writes.
4//! # Example
5//! ```
6//! use write_monitor::WriteMonitor;
7//! use std::io::Write;
//! let buf = std::fs::File::create("somefile").unwrap();
9//! let mut wm = WriteMonitor::new(buf);
10//! let big_data = std::fs::read("Cargo.toml").unwrap();
11//! let big_data_len = big_data.len();
12//! let monitor = wm.monitor();
13//! std::thread::spawn(move || {
14//!     wm.write_all(&big_data).unwrap();
15//! });
16//! let mut last_written = 0;
//! while monitor.bytes_written() < big_data_len as u64 {
//!     let written = monitor.bytes_written();
//!     if written != last_written {
//!         println!("{} bytes written", written);
//!         last_written = written;
//!     }
//!     std::thread::sleep(std::time::Duration::from_millis(100));
//! }
25//! ```
26
27extern crate alloc;
28use alloc::sync::Arc;
29use core::sync::atomic::{AtomicU64, Ordering};
30
31#[cfg(any(feature = "futures", feature = "tokio"))]
32use core::{pin::Pin, task::Poll};
33
34#[cfg_attr(any(feature = "futures", feature = "tokio"), pin_project::pin_project)]
35#[derive(Debug, Clone)]
36pub struct WriteMonitor<W> {
37    #[cfg_attr(any(feature = "futures", feature = "tokio"), pin)]
38    inner: W,
39    bytes_written: Arc<AtomicU64>,
40}
41
42impl<W> WriteMonitor<W> {
43    pub fn new(inner: W) -> Self {
44        Self {
45            inner,
46            bytes_written: Arc::new(AtomicU64::new(0)),
47        }
48    }
49
50    pub fn bytes_written(&self) -> u64 {
51        self.bytes_written.load(Ordering::Acquire)
52    }
53
54    /// If the writer is dropped the monitor doesn't need to be dropped but it will stop updating.
55    pub fn monitor(&self) -> Monitor {
56        Monitor {
57            bytes_written: self.bytes_written.clone(),
58        }
59    }
60
61    pub fn into_inner(self) -> W {
62        self.inner
63    }
64}
65
/// A handle for observing how many bytes a [`WriteMonitor`] has written.
///
/// Created by [`WriteMonitor::monitor`]. Clones are cheap and all observe the
/// same shared counter.
#[derive(Debug, Clone)]
pub struct Monitor {
    /// Counter shared with the `WriteMonitor` that created this handle.
    bytes_written: Arc<AtomicU64>,
}

impl Monitor {
    /// Returns the current value of the shared byte counter.
    pub fn bytes_written(&self) -> u64 {
        let counter: &AtomicU64 = &self.bytes_written;
        counter.load(Ordering::Acquire)
    }

    /// Consumes the handle and yields the underlying shared counter.
    pub fn into_inner(self) -> Arc<AtomicU64> {
        let Self { bytes_written } = self;
        bytes_written
    }
}
80
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
// No `Unpin` bound is needed: `inner` is structurally pinned via `pin_project`,
// so `project()` already yields a `Pin<&mut W>` for any `W`.
impl<W: tokio::io::AsyncWrite> tokio::io::AsyncWrite for WriteMonitor<W> {
    /// Forwards to the inner writer and, on success, adds the number of bytes
    /// it accepted to the shared counter. Pending or error results count nothing.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write(cx, buf);
        if let Poll::Ready(Ok(n)) = result {
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    /// Forwards vectored writes so the inner writer's vectored I/O is kept,
    /// counting the bytes it reports as written.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(n)) = result {
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    /// Reports whether the inner writer has an efficient vectored write.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Flushes the inner writer; the byte count is unaffected.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.project().inner.poll_flush(cx)
    }

    /// Shuts down the inner writer; the byte count is unaffected.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        self.project().inner.poll_shutdown(cx)
    }
}
113
#[cfg(feature = "futures")]
// No `Unpin` bound is needed: `inner` is structurally pinned via `pin_project`,
// so `project()` already yields a `Pin<&mut W>` for any `W`.
impl<W: futures::io::AsyncWrite> futures::io::AsyncWrite for WriteMonitor<W> {
    /// Forwards to the inner writer and, on success, adds the number of bytes
    /// it accepted to the shared counter. Pending or error results count nothing.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<futures::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write(cx, buf);
        if let Poll::Ready(Ok(n)) = result {
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    /// Forwards vectored writes so the inner writer's vectored I/O is kept,
    /// counting the bytes it reports as written.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
        bufs: &[futures::io::IoSlice<'_>],
    ) -> Poll<futures::io::Result<usize>> {
        let this = self.project();
        let result = this.inner.poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(n)) = result {
            this.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        }
        result
    }

    /// Flushes the inner writer; the byte count is unaffected.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<futures::io::Result<()>> {
        self.project().inner.poll_flush(cx)
    }

    /// Closes the inner writer; the byte count is unaffected.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<futures::io::Result<()>> {
        self.project().inner.poll_close(cx)
    }
}
143
#[cfg(feature = "std")]
impl<W: std::io::Write> std::io::Write for WriteMonitor<W> {
    /// Writes to the inner writer and adds the number of bytes it accepted to
    /// the shared counter. Errors are returned unchanged and count nothing.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let n = std::io::Write::write(&mut self.inner, buf)?;
        self.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        Ok(n)
    }

    /// Forwards vectored writes to the inner writer, counting the bytes it
    /// reports as written. Without this override, std's default would write
    /// only the first non-empty buffer per call, losing the inner writer's
    /// vectored implementation.
    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
        let n = std::io::Write::write_vectored(&mut self.inner, bufs)?;
        self.bytes_written.fetch_add(n as u64, Ordering::AcqRel);
        Ok(n)
    }

    /// Flushes the inner writer; the byte count is unaffected.
    fn flush(&mut self) -> std::io::Result<()> {
        std::io::Write::flush(&mut self.inner)
    }
}
157
// Exercises the `std::io::Write` impl, which only exists with the `std`
// feature; without this gate `cargo test --no-default-features` fails to build.
#[cfg(feature = "std")]
#[test]
fn test_write_monitor() {
    use std::io::Write;

    let mut buf = Vec::new();
    let mut wm = WriteMonitor::new(&mut buf);
    let data = b"Hello World";
    let data_len = data.len() as u64;
    let monitor = wm.monitor();
    assert_eq!(monitor.bytes_written(), 0);

    wm.write_all(data).unwrap();
    assert_eq!(monitor.bytes_written(), data_len);
    assert_eq!(wm.bytes_written(), data_len);

    // A second write accumulates, and the monitor keeps the total once the
    // writer is gone.
    wm.write_all(data).unwrap();
    drop(wm);
    assert_eq!(monitor.bytes_written(), 2 * data_len);
    assert_eq!(buf, b"Hello WorldHello World");
}