1use std::{io, io::BufWriter as IoBufWriter, mem::MaybeUninit};
4
5use bytes::{buf::Writer, BufMut, BytesMut};
6
/// Extension of [`io::Write`] for writers that can hand out uninitialised
/// room to be filled in place and then committed.
///
/// The protocol is:
///
/// 1. call [`reserve_with`](WriteExt::reserve_with) to obtain a slice of
///    uninitialised bytes;
/// 2. initialise (a prefix of) that slice;
/// 3. call [`flush_len`](WriteExt::flush_len) with the number of bytes that
///    were initialised.
pub trait WriteExt: io::Write {
    /// Reserves room for `additional` more bytes and returns that room as a
    /// slice of length `additional` that the caller may fill in place.
    ///
    /// Nothing counts as written until [`flush_len`](WriteExt::flush_len) is
    /// called.
    ///
    /// # Errors
    ///
    /// Returns an error if the room cannot be made available, for example
    /// because data buffered earlier could not be written out first.
    fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]>;

    /// Commits `additional` bytes previously filled in through
    /// [`reserve_with`](WriteExt::reserve_with) as written data.
    ///
    /// # Safety
    ///
    /// - The last operation on this writer must have been
    ///   `reserve_with(n)` with `n >= additional`; no other write may happen
    ///   in between.
    /// - The first `additional` bytes of the slice returned by that call must
    ///   have been initialised.
    ///
    /// # Errors
    ///
    /// Returns any error the implementation encounters while committing the
    /// bytes (for example, while writing them through to an inner writer).
    unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()>;
}
24
/// Writer that stages bytes handed out by `WriteExt::reserve_with` in an
/// owned `Vec<u8>` and writes them through to `inner` on
/// `WriteExt::flush_len`; plain `io::Write` calls go to `inner` directly.
pub struct BufferedWriter<W> {
    /// Destination for both committed and directly written bytes.
    inner: W,
    /// Staging area for reserved-but-not-yet-committed bytes.
    buffer: Vec<u8>,
}

impl<W> BufferedWriter<W> {
    /// Wraps `inner` with an empty staging buffer.
    ///
    /// The staging buffer does not allocate until the first reservation.
    pub fn new(inner: W) -> Self {
        let buffer = Vec::new();
        Self { inner, buffer }
    }
}
43
44impl<W> io::Write for BufferedWriter<W>
45where
46 W: io::Write,
47{
48 #[inline(always)]
49 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
50 self.inner.write(buf)
51 }
52
53 #[inline(always)]
54 fn flush(&mut self) -> io::Result<()> {
55 self.inner.flush()
56 }
57}
58
59impl<W> WriteExt for BufferedWriter<W>
60where
61 W: io::Write,
62{
63 #[inline(always)]
64 fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]> {
65 self.buffer.reserve_with(additional)
66 }
67
68 #[inline(always)]
69 unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()> {
70 self.buffer.flush_len(additional)?;
71 self.inner.write_all(&self.buffer)?;
72 self.buffer.clear();
73
74 Ok(())
75 }
76}
77
78impl WriteExt for Vec<u8> {
79 #[inline(always)]
80 fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]> {
81 self.reserve(additional);
82 Ok(&mut self.spare_capacity_mut()[..additional])
83 }
84
85 #[inline(always)]
86 unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()> {
87 unsafe {
88 let new_len = self.len() + additional;
89 self.set_len(new_len);
90 }
91
92 Ok(())
93 }
94}
95
96impl WriteExt for Writer<BytesMut> {
97 #[inline(always)]
98 unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()> {
99 let new_len = self.get_ref().len() + additional;
100 self.get_mut().set_len(new_len);
101 Ok(())
102 }
103
104 #[inline(always)]
105 fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]> {
106 self.get_mut().reserve(additional);
107 let ptr = unsafe { self.get_mut().chunk_mut().as_uninit_slice_mut() };
108 Ok(&mut ptr[..additional])
109 }
110}
111
112impl WriteExt for Writer<&mut BytesMut> {
113 #[inline(always)]
114 unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()> {
115 let new_len = self.get_ref().len() + additional;
116 self.get_mut().set_len(new_len);
117 Ok(())
118 }
119
120 #[inline(always)]
121 fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]> {
122 self.get_mut().reserve(additional);
123 let ptr = unsafe { self.get_mut().chunk_mut().as_uninit_slice_mut() };
124 Ok(&mut ptr[..additional])
125 }
126}
127
128impl<W: WriteExt + ?Sized> WriteExt for IoBufWriter<W> {
129 fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]> {
130 self.get_mut().reserve_with(additional)
131 }
132
133 unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()> {
134 self.get_mut().flush_len(additional)
135 }
136}
137
138impl<W: WriteExt + ?Sized> WriteExt for &mut W {
139 #[inline(always)]
140 unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()> {
141 (*self).flush_len(additional)
142 }
143
144 #[inline(always)]
145 fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]> {
146 (*self).reserve_with(additional)
147 }
148}
149
150impl<W: WriteExt + ?Sized> WriteExt for Box<W> {
151 #[inline(always)]
152 unsafe fn flush_len(&mut self, additional: usize) -> io::Result<()> {
153 (**self).flush_len(additional)
154 }
155
156 #[inline(always)]
157 fn reserve_with(&mut self, additional: usize) -> io::Result<&mut [MaybeUninit<u8>]> {
158 (**self).reserve_with(additional)
159 }
160}
161
#[cfg(test)]
mod test {
    use std::io::Write;

    use bytes::{BufMut, BytesMut};

    use super::{BufferedWriter, WriteExt};

    /// Reserves `data.len()` bytes on `writer` and fills them with `data`.
    fn stage<W: WriteExt + ?Sized>(writer: &mut W, data: &[u8]) {
        let spare = writer
            .reserve_with(data.len())
            .expect("reserve_with should succeed");
        assert_eq!(spare.len(), data.len());
        for (slot, &byte) in spare.iter_mut().zip(data) {
            slot.write(byte);
        }
    }

    #[test]
    fn test_writer() {
        let mut writer = BytesMut::new().writer();

        // `expect` rather than `unwrap_or_default`: an error must fail loudly,
        // not masquerade as an empty slice.
        let spare = writer
            .reserve_with(20)
            .expect("reserve_with should succeed");
        assert_eq!(spare.len(), 20);
        assert_eq!(writer.get_ref().capacity(), 20);

        let data = b"Hello, World!";
        writer.write_all(&data[..]).unwrap();
        assert_eq!(writer.get_ref().capacity(), 20);
        assert_eq!(writer.get_ref().as_ref(), &data[..]);
    }

    #[test]
    fn test_writer_commit() {
        let mut writer = BytesMut::new().writer();
        stage(&mut writer, b"hello");
        // SAFETY: `stage` reserved and initialised exactly 5 bytes.
        unsafe { writer.flush_len(5) }.expect("flush_len should succeed");
        assert_eq!(writer.get_ref().as_ref(), &b"hello"[..]);
    }

    #[test]
    fn test_vec_commit() {
        let mut out: Vec<u8> = b"ab".to_vec();
        stage(&mut out, b"cd");
        // SAFETY: `stage` reserved and initialised exactly 2 bytes.
        unsafe { out.flush_len(2) }.expect("flush_len should succeed");
        assert_eq!(out, b"abcd");
    }

    #[test]
    fn test_buffered_writer_commit() {
        let mut writer = BufferedWriter::new(Vec::<u8>::new());
        writer.write_all(b"x").unwrap();
        stage(&mut writer, b"yz");
        // SAFETY: `stage` reserved and initialised exactly 2 bytes.
        unsafe { writer.flush_len(2) }.expect("flush_len should succeed");
        assert_eq!(writer.inner, b"xyz");
        assert!(writer.buffer.is_empty());
    }
}