1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use futures_core::ready;
8use futures_sink::Sink;
9use ordered_varint::Variable;
10use tokio::io::AsyncWrite;
11use transmog::Format;
12
/// A [`Sink`] that serializes values of type `T` with the format `F` and writes the
/// resulting bytes to the wrapped writer `W`.
///
/// `D` is a type-level destination marker ([`AsyncDestination`] or [`SyncDestination`])
/// that selects how each item is framed in the output buffer.
#[derive(Debug)]
pub struct TransmogWriter<W, T, D, F> {
    // Serialization format used to encode every item.
    format: F,
    // Underlying writer that buffered bytes are eventually written and flushed to.
    writer: W,
    // Number of bytes at the front of `buffer` already accepted by `writer`
    // during an in-progress flush.
    pub(crate) written: usize,
    // Serialized bytes waiting to be written to `writer`.
    pub(crate) buffer: Vec<u8>,
    // Temporary space used by the async framing when a payload's size is not known
    // before serialization.
    scratch_buffer: Vec<u8>,
    // Ties the item type `T` to the writer without storing a value of it.
    from: PhantomData<T>,
    // Ties the destination marker `D` to the writer without storing a value of it.
    dest: PhantomData<D>,
}
32
33impl<W, T, D, F> Unpin for TransmogWriter<W, T, D, F> where W: Unpin {}
34
35impl<W, T, D, F> TransmogWriter<W, T, D, F> {
36 pub fn format(&self) -> &F {
40 &self.format
41 }
42
43 pub fn get_ref(&self) -> &W {
47 &self.writer
48 }
49
50 pub fn get_mut(&mut self) -> &mut W {
54 &mut self.writer
55 }
56
57 pub fn into_inner(self) -> (W, F) {
61 (self.writer, self.format)
62 }
63}
64
65impl<W, T, F> TransmogWriter<W, T, SyncDestination, F> {
66 pub fn new(writer: W, format: F) -> Self {
68 TransmogWriter {
69 format,
70 buffer: Vec::new(),
71 scratch_buffer: Vec::new(),
72 writer,
73 written: 0,
74 from: PhantomData,
75 dest: PhantomData,
76 }
77 }
78
79 pub fn default_for(format: F) -> Self
82 where
83 W: Default,
84 {
85 Self::new(W::default(), format)
86 }
87}
88
89impl<W, T, F> TransmogWriter<W, T, SyncDestination, F> {
90 pub fn for_async(self) -> TransmogWriter<W, T, AsyncDestination, F> {
94 self.make_for()
95 }
96}
97
98impl<W, T, D, F> TransmogWriter<W, T, D, F> {
99 pub(crate) fn make_for<D2>(self) -> TransmogWriter<W, T, D2, F> {
100 TransmogWriter {
101 format: self.format,
102 buffer: self.buffer,
103 writer: self.writer,
104 written: self.written,
105 from: self.from,
106 scratch_buffer: self.scratch_buffer,
107 dest: PhantomData,
108 }
109 }
110}
111
112impl<W, T, F> TransmogWriter<W, T, AsyncDestination, F> {
113 pub fn for_sync(self) -> TransmogWriter<W, T, SyncDestination, F> {
117 self.make_for()
118 }
119}
120
/// Destination marker selecting length-prefixed framing.
///
/// When used as the `D` parameter of `TransmogWriter`, every appended item is preceded by
/// its serialized length, encoded with `ordered_varint`.
///
/// This is a zero-sized marker, so it derives the common value traits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AsyncDestination;
124
/// Destination marker selecting unframed output.
///
/// When used as the `D` parameter of `TransmogWriter`, each item's serialized bytes are
/// appended directly, with no length prefix.
///
/// This is a zero-sized marker, so it derives the common value traits.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SyncDestination;
128
/// Implementation detail: serializes one item into a writer's pending output.
///
/// It is implemented once per destination marker, so the framing of each item depends on
/// the writer's destination type. The `Sink` implementation for `TransmogWriter` requires
/// this trait.
#[doc(hidden)]
pub trait TransmogWriterFor<T, F>
where
    F: Format<'static, T>,
{
    /// Serializes `item` and appends the result to the pending output.
    ///
    /// # Errors
    ///
    /// Returns the format's error type if `item` cannot be encoded.
    fn append(&mut self, item: &T) -> Result<(), F::Error>;
}
136
137impl<W, T, F> TransmogWriterFor<T, F> for TransmogWriter<W, T, AsyncDestination, F>
138where
139 F: Format<'static, T>,
140{
141 fn append(&mut self, item: &T) -> Result<(), F::Error> {
142 if let Some(serialized_length) = self.format.serialized_size(item)? {
143 let size = usize_to_u64(serialized_length)?;
144 size.encode_variable(&mut self.buffer)?;
145 self.format.serialize_into(item, &mut self.buffer)?;
146 } else {
147 self.scratch_buffer.truncate(0);
151 self.format.serialize_into(item, &mut self.scratch_buffer)?;
152
153 let size = usize_to_u64(self.scratch_buffer.len())?;
154 size.encode_variable(&mut self.buffer)?;
155 self.buffer.append(&mut self.scratch_buffer);
156 }
157 Ok(())
158 }
159}
160
/// Widens a `usize` length to `u64` for encoding as a varint prefix.
///
/// # Errors
///
/// Returns an `OutOfMemory` I/O error on platforms where the value does not fit in a `u64`.
fn usize_to_u64(value: usize) -> Result<u64, std::io::Error> {
    match u64::try_from(value) {
        Ok(widened) => Ok(widened),
        Err(_) => Err(std::io::ErrorKind::OutOfMemory.into()),
    }
}
164
165impl<W, T, F> TransmogWriterFor<T, F> for TransmogWriter<W, T, SyncDestination, F>
166where
167 F: Format<'static, T>,
168{
169 fn append(&mut self, item: &T) -> Result<(), F::Error> {
170 self.format.serialize_into(item, &mut self.buffer)
171 }
172}
173
174impl<W, T, D, F> Sink<T> for TransmogWriter<W, T, D, F>
175where
176 F: Format<'static, T>,
177 W: AsyncWrite + Unpin,
178 Self: TransmogWriterFor<T, F>,
179{
180 type Error = F::Error;
181
182 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183 Poll::Ready(Ok(()))
184 }
185
186 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
187 self.append(&item)?;
188 Ok(())
189 }
190
191 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
192 let this = self.get_mut();
194
195 while this.written != this.buffer.len() {
197 let n =
198 ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buffer[this.written..]))?;
199 this.written += n;
200 }
201
202 this.buffer.clear();
204 this.written = 0;
205 Pin::new(&mut this.writer)
206 .poll_flush(cx)
207 .map_err(<F::Error as From<std::io::Error>>::from)
208 }
209
210 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
211 ready!(self.as_mut().poll_flush(cx))?;
212 Pin::new(&mut self.writer)
213 .poll_shutdown(cx)
214 .map_err(<F::Error as From<std::io::Error>>::from)
215 }
216}