1use ::core::future::Future;
2use ::core::mem;
3use ::core::str;
4
5use std::sync::Arc;
6
7use leb128_tokio::{AsyncReadLeb128, Leb128DecoderU32, Leb128Encoder};
8use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
9use tokio_util::bytes::{BufMut as _, Bytes, BytesMut};
10use tokio_util::codec::{Decoder, Encoder};
11
12pub trait AsyncReadCore: AsyncRead {
13 #[cfg_attr(
15 feature = "tracing",
16 tracing::instrument(level = "trace", ret, skip_all, fields(ty = "name"))
17 )]
18 fn read_core_name(&mut self, s: &mut String) -> impl Future<Output = std::io::Result<()>>
19 where
20 Self: Unpin + Sized,
21 {
22 async move {
23 let n = self.read_u32_leb128().await?;
24 s.reserve(n.try_into().unwrap_or(usize::MAX));
25 self.take(n.into()).read_to_string(s).await?;
26 Ok(())
27 }
28 }
29}
30
31impl<T: AsyncRead> AsyncReadCore for T {}
32
33pub trait AsyncWriteCore: AsyncWrite {
34 #[cfg_attr(
36 feature = "tracing",
37 tracing::instrument(level = "trace", ret, skip_all, fields(ty = "name"))
38 )]
39 fn write_core_name(&mut self, s: &str) -> impl Future<Output = std::io::Result<()>>
40 where
41 Self: Unpin,
42 {
43 async move {
44 let mut buf = BytesMut::with_capacity(5usize.saturating_add(s.len()));
45 CoreNameEncoder.encode(s, &mut buf)?;
46 self.write_all(&buf).await
47 }
48 }
49}
50
51impl<T: AsyncWrite> AsyncWriteCore for T {}
52
/// Encoder for names: writes the UTF-8 byte length as a LEB128 `u32`, then the string bytes.
///
/// Accepts anything implementing `AsRef<str>` (`&str`, `String`, `Arc<str>`, nested references).
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct CoreNameEncoder;
56
57impl<T: AsRef<str>> Encoder<T> for CoreNameEncoder {
58 type Error = std::io::Error;
59
60 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
61 let item = item.as_ref();
62 let len = item.len();
63 let n: u32 = len
64 .try_into()
65 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
66 dst.reserve(len + 5 - n.leading_zeros() as usize / 7);
67 Leb128Encoder.encode(n, dst)?;
68 dst.put(item.as_bytes());
69 Ok(())
70 }
71}
72
/// Decoder for names: decodes a length-prefixed byte payload with [`CoreVecDecoderBytes`]
/// and validates it as UTF-8.
#[derive(Debug, Default)]
pub struct CoreNameDecoder(CoreVecDecoderBytes);
76
77impl Decoder for CoreNameDecoder {
78 type Item = String;
79 type Error = std::io::Error;
80
81 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
82 let Some(buf) = self.0.decode(src)? else {
83 return Ok(None);
84 };
85 let s = str::from_utf8(&buf)
86 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
87 Ok(Some(s.to_string()))
88 }
89}
90
/// Encoder for length-prefixed sequences: writes the element count as a LEB128 `u32`, then
/// each element using the wrapped element encoder `E`.
///
/// The derives are conditional on `E`, matching the other codec types in this module.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct CoreVecEncoder<E>(pub E);
93
94impl<E, T, const N: usize> Encoder<[T; N]> for CoreVecEncoder<E>
95where
96 E: Encoder<T>,
97 std::io::Error: From<E::Error>,
98{
99 type Error = std::io::Error;
100
101 fn encode(&mut self, item: [T; N], dst: &mut BytesMut) -> Result<(), Self::Error> {
102 dst.reserve(5 + N);
103 let len = u32::try_from(N)
104 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
105 Leb128Encoder.encode(len, dst)?;
106 for item in item {
107 self.0.encode(item, dst)?;
108 }
109 Ok(())
110 }
111}
112
113impl<E, T> Encoder<Vec<T>> for CoreVecEncoder<E>
114where
115 E: Encoder<T>,
116 std::io::Error: From<E::Error>,
117{
118 type Error = std::io::Error;
119
120 fn encode(&mut self, item: Vec<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
121 let len = item.len();
122 dst.reserve(5 + len);
123 let len = u32::try_from(len)
124 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
125 Leb128Encoder.encode(len, dst)?;
126 for item in item {
127 self.0.encode(item, dst)?;
128 }
129 Ok(())
130 }
131}
132
133impl<E, T> Encoder<Box<[T]>> for CoreVecEncoder<E>
134where
135 E: Encoder<T>,
136 std::io::Error: From<E::Error>,
137{
138 type Error = std::io::Error;
139
140 fn encode(&mut self, item: Box<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
141 self.encode(Vec::from(item), dst)
142 }
143}
144
145impl<'a, E, T, const N: usize> Encoder<&'a [T; N]> for CoreVecEncoder<E>
146where
147 E: Encoder<&'a T>,
148 std::io::Error: From<E::Error>,
149{
150 type Error = std::io::Error;
151
152 fn encode(&mut self, item: &'a [T; N], dst: &mut BytesMut) -> Result<(), Self::Error> {
153 dst.reserve(5 + N);
154 let len = u32::try_from(N)
155 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
156 Leb128Encoder.encode(len, dst)?;
157 for item in item {
158 self.0.encode(item, dst)?;
159 }
160 Ok(())
161 }
162}
163
164impl<'a, E, T> Encoder<&'a [T]> for CoreVecEncoder<E>
165where
166 E: Encoder<&'a T>,
167 std::io::Error: From<E::Error>,
168{
169 type Error = std::io::Error;
170
171 fn encode(&mut self, item: &'a [T], dst: &mut BytesMut) -> Result<(), Self::Error> {
172 let len = item.len();
173 dst.reserve(5 + len);
174 let len = u32::try_from(len)
175 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
176 Leb128Encoder.encode(len, dst)?;
177 for item in item {
178 self.0.encode(item, dst)?;
179 }
180 Ok(())
181 }
182}
183
184impl<'a, 'b, E, T> Encoder<&'a &'b [T]> for CoreVecEncoder<E>
185where
186 E: Encoder<&'b T>,
187 std::io::Error: From<E::Error>,
188{
189 type Error = std::io::Error;
190
191 fn encode(&mut self, item: &'a &'b [T], dst: &mut BytesMut) -> Result<(), Self::Error> {
192 self.encode(*item, dst)
193 }
194}
195
196impl<'a, E, T> Encoder<&'a Vec<T>> for CoreVecEncoder<E>
197where
198 E: Encoder<&'a T>,
199 std::io::Error: From<E::Error>,
200{
201 type Error = std::io::Error;
202
203 fn encode(&mut self, item: &'a Vec<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
204 self.encode(item.as_slice(), dst)
205 }
206}
207
208impl<'a, 'b, E, T> Encoder<&'a &'b Vec<T>> for CoreVecEncoder<E>
209where
210 E: Encoder<&'b T>,
211 std::io::Error: From<E::Error>,
212{
213 type Error = std::io::Error;
214
215 fn encode(&mut self, item: &'a &'b Vec<T>, dst: &mut BytesMut) -> Result<(), Self::Error> {
216 self.encode(item.as_slice(), dst)
217 }
218}
219
220impl<'a, E, T> Encoder<&'a Box<[T]>> for CoreVecEncoder<E>
221where
222 E: Encoder<&'a T>,
223 std::io::Error: From<E::Error>,
224{
225 type Error = std::io::Error;
226
227 fn encode(&mut self, item: &'a Box<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
228 let item: &[T] = item.as_ref();
229 self.encode(item, dst)
230 }
231}
232
233impl<E, T> Encoder<Arc<[T]>> for CoreVecEncoder<E>
234where
235 for<'a> E: Encoder<&'a T>,
236 for<'a> std::io::Error: From<<E as Encoder<&'a T>>::Error>,
237{
238 type Error = std::io::Error;
239
240 fn encode(&mut self, item: Arc<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
241 let item = item.as_ref();
242 let len = item.len();
243 dst.reserve(5 + len);
244 let len = u32::try_from(len)
245 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
246 Leb128Encoder.encode(len, dst)?;
247 for item in item {
248 self.0.encode(item, dst)?;
249 }
250 Ok(())
251 }
252}
253
254impl<'a, E, T> Encoder<&'a Arc<[T]>> for CoreVecEncoder<E>
255where
256 E: Encoder<&'a T>,
257 std::io::Error: From<E::Error>,
258{
259 type Error = std::io::Error;
260
261 fn encode(&mut self, item: &'a Arc<[T]>, dst: &mut BytesMut) -> Result<(), Self::Error> {
262 let item: &[T] = item.as_ref();
263 self.encode(item, dst)
264 }
265}
266
/// Decoder for length-prefixed sequences: reads a LEB128 `u32` element count, then decodes
/// that many elements with the wrapped decoder `T`.
///
/// Decoding is resumable: progress on a partially received sequence is kept in the struct
/// between `decode` calls.
#[derive(Debug)]
pub struct CoreVecDecoder<T: Decoder> {
    // Element decoder.
    dec: T,
    // Elements decoded so far for the sequence currently in progress.
    ret: Vec<T::Item>,
    // Number of elements still to decode; 0 means the next bytes are a length prefix.
    cap: usize,
}
274
275impl<T> CoreVecDecoder<T>
276where
277 T: Decoder,
278{
279 pub fn new(decoder: T) -> Self {
280 Self {
281 dec: decoder,
282 ret: Vec::default(),
283 cap: 0,
284 }
285 }
286
287 pub fn into_inner(self) -> T {
288 self.dec
289 }
290}
291
292impl<T> Default for CoreVecDecoder<T>
293where
294 T: Decoder + Default,
295{
296 fn default() -> Self {
297 Self::new(T::default())
298 }
299}
300
301impl<T> Decoder for CoreVecDecoder<T>
302where
303 T: Decoder,
304{
305 type Item = Vec<T::Item>;
306 type Error = T::Error;
307
308 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
309 if self.cap == 0 {
310 let Some(len) = Leb128DecoderU32.decode(src)? else {
311 return Ok(None);
312 };
313 if len == 0 {
314 return Ok(Some(Vec::default()));
315 }
316 let len = len
317 .try_into()
318 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
319 self.ret = Vec::with_capacity(len);
320 self.cap = len;
321 }
322 while self.cap > 0 {
323 let Some(v) = self.dec.decode(src)? else {
324 return Ok(None);
325 };
326 self.ret.push(v);
327 self.cap -= 1;
328 }
329 Ok(Some(mem::take(&mut self.ret)))
330 }
331}
332
/// Encoder for byte strings: writes the byte length as a LEB128 `u32`, then the raw bytes.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct CoreVecEncoderBytes;
337
338impl<T: AsRef<[u8]>> Encoder<T> for CoreVecEncoderBytes {
339 type Error = std::io::Error;
340
341 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
342 let item = item.as_ref();
343 let n = item.len();
344 let n = u32::try_from(n)
345 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
346 dst.reserve(item.len().saturating_add(5));
347 Leb128Encoder.encode(n, dst)?;
348 dst.extend_from_slice(item);
349 Ok(())
350 }
351}
352
/// Decoder for byte strings: reads a LEB128 `u32` length, then returns that many bytes as
/// [`Bytes`] without copying.
///
/// The wrapped `usize` is the payload length of the frame in progress; 0 means the next bytes
/// are a length prefix.
#[derive(Debug, Default)]
pub struct CoreVecDecoderBytes(usize);
357
358impl Decoder for CoreVecDecoderBytes {
359 type Item = Bytes;
360 type Error = std::io::Error;
361
362 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
363 if self.0 == 0 {
364 let Some(len) = Leb128DecoderU32.decode(src)? else {
365 return Ok(None);
366 };
367 if len == 0 {
368 return Ok(Some(Bytes::default()));
369 }
370 let len = len
371 .try_into()
372 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
373 self.0 = len;
374 }
375 let n = self.0.saturating_sub(src.len());
376 if n > 0 {
377 src.reserve(n);
378 return Ok(None);
379 }
380 let buf = src.split_to(self.0);
381 self.0 = 0;
382 Ok(Some(buf.freeze()))
383 }
384}
385
#[cfg(test)]
mod tests {
    use futures::{SinkExt as _, TryStreamExt as _};
    use tokio_util::codec::{FramedRead, FramedWrite};
    use tracing::trace;

    use super::*;

    /// Round-trips names through `read_core_name`/`write_core_name` and through the
    /// `CoreNameEncoder`/`CoreNameDecoder` pair. Covers an empty name, multi-byte UTF-8, and
    /// several reference and smart-pointer wrappers.
    #[test_log::test(tokio::test)]
    async fn string() {
        let mut s = String::new();
        "\x04test"
            .as_bytes()
            .read_core_name(&mut s)
            .await
            .expect("failed to read string");
        assert_eq!(s, "test");

        let mut buf = vec![];
        buf.write_core_name("test")
            .await
            .expect("failed to write string");
        assert_eq!(buf, b"\x04test");

        let mut tx = FramedWrite::new(Vec::new(), CoreNameEncoder);

        trace!("sending `foo`");
        tx.send("foo").await.expect("failed to send `foo`");

        trace!("sending ``");
        tx.send(String::default()).await.expect("failed to send ``");

        trace!("sending `test`");
        tx.send(&&&&&&"test").await.expect("failed to send `test`");

        trace!("sending `bar`");
        tx.send(Arc::from("bar"))
            .await
            .expect("failed to send `bar`");

        trace!("sending `ƒ𐍈Ő`");
        tx.send(&&String::from("ƒ𐍈Ő"))
            .await
            .expect("failed to send `ƒ𐍈Ő`");

        trace!("sending `baz`");
        tx.send(&&&Arc::from("baz"))
            .await
            .expect("failed to send `baz`");

        let tx = tx.into_inner();
        assert_eq!(
            tx,
            concat!("\x03foo", "\0", "\x04test", "\x03bar", "\x08ƒ𐍈Ő", "\x03baz").as_bytes()
        );
        let mut rx = FramedRead::new(tx.as_slice(), CoreNameDecoder::default());

        trace!("reading `foo`");
        let s = rx.try_next().await.expect("failed to get `foo`");
        assert_eq!(s.as_deref(), Some("foo"));

        trace!("reading ``");
        let s = rx.try_next().await.expect("failed to get ``");
        assert_eq!(s.as_deref(), Some(""));

        trace!("reading `test`");
        let s = rx.try_next().await.expect("failed to get `test`");
        assert_eq!(s.as_deref(), Some("test"));

        trace!("reading `bar`");
        let s = rx.try_next().await.expect("failed to get `bar`");
        assert_eq!(s.as_deref(), Some("bar"));

        trace!("reading `ƒ𐍈Ő`");
        let s = rx.try_next().await.expect("failed to get `ƒ𐍈Ő`");
        assert_eq!(s.as_deref(), Some("ƒ𐍈Ő"));

        trace!("reading `baz`");
        let s = rx.try_next().await.expect("failed to get `baz`");
        assert_eq!(s.as_deref(), Some("baz"));

        let s = rx.try_next().await.expect("failed to get EOF");
        assert_eq!(s, None);
    }

    /// Round-trips sequences of names through `CoreVecEncoder`/`CoreVecDecoder`, including
    /// empty sequences.
    #[test_log::test(tokio::test)]
    async fn vec() {
        let mut tx = FramedWrite::new(Vec::new(), CoreVecEncoder(CoreNameEncoder));

        trace!("sending [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");
        tx.send(&["foo", "", "test", "bar", "ƒ𐍈Ő", "baz"])
            .await
            .expect("failed to send [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");

        trace!("sending []");
        tx.send(&[""; 0]).await.expect("failed to send []");

        trace!("sending [`test`]");
        tx.send(&["test"]).await.expect("failed to send [`test`]");

        trace!("sending []");
        tx.send(&[""; 0]).await.expect("failed to send []");

        let tx = tx.into_inner();
        assert_eq!(
            tx,
            concat!(
                concat!(
                    "\x06",
                    concat!("\x03foo", "\0", "\x04test", "\x03bar", "\x08ƒ𐍈Ő", "\x03baz")
                ),
                "\0",
                concat!("\x01", "\x04test"),
                "\0"
            )
            .as_bytes()
        );
        let mut rx = FramedRead::new(tx.as_slice(), CoreVecDecoder::<CoreNameDecoder>::default());

        trace!("reading [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");
        let s = rx
            .try_next()
            .await
            .expect("failed to get [`foo`, ``, `test`, `bar`, `ƒ𐍈Ő`, `baz`]");
        assert_eq!(
            s.as_deref(),
            Some(
                [
                    "foo".to_string(),
                    String::new(),
                    "test".to_string(),
                    "bar".to_string(),
                    "ƒ𐍈Ő".to_string(),
                    "baz".to_string()
                ]
                .as_slice()
            )
        );

        trace!("reading []");
        let s = rx.try_next().await.expect("failed to get []");
        assert_eq!(s.as_deref(), Some([].as_slice()));

        trace!("reading [`test`]");
        let s = rx.try_next().await.expect("failed to get [`test`]");
        assert_eq!(s.as_deref(), Some(["test".to_string()].as_slice()));

        trace!("reading []");
        let s = rx.try_next().await.expect("failed to get []");
        assert_eq!(s.as_deref(), Some([].as_slice()));

        let s = rx.try_next().await.expect("failed to get EOF");
        assert_eq!(s, None);
    }
}