tiny_rpc/
io.rs

1use std::{
2    convert::TryInto,
3    mem::size_of,
4    pin::Pin,
5    sync::{
6        atomic::{AtomicU64, Ordering},
7        Arc,
8    },
9};
10
11use bincode::{deserialize, serialize_into, serialized_size};
12use bytes::{BufMut, Bytes, BytesMut};
13use futures::{channel::mpsc, future::ready, Sink, SinkExt, Stream, StreamExt};
14use serde::{Deserialize, Serialize};
15use tokio::io::{split, AsyncRead, AsyncWrite};
16use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
17
18use crate::error::{Error, Result};
19
/// Opaque 64-bit identifier carried in the header of an [`RpcFrame`].
///
/// `#[repr(transparent)]` guarantees the layout is exactly that of a `u64`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct Id(u64);

impl Id {
    /// The all-zero identifier.
    pub const NULL: Self = Self(0);
}

impl std::fmt::Display for Id {
    /// Renders the id as 16 zero-padded upper-case hex digits wrapped in square
    /// brackets, e.g. `[00000000000000FF]`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = *self;
        write!(formatter, "[{:016X}]", raw)
    }
}
33
34#[derive(Clone)]
35pub struct IdGenerator(Arc<AtomicU64>);
36
37impl IdGenerator {
38    pub fn new() -> Self {
39        Self(Arc::new(AtomicU64::new(5)))
40    }
41
42    pub fn next(&self) -> Id {
43        Id(self.0.fetch_add(1, Ordering::SeqCst))
44    }
45}
46
47impl Default for IdGenerator {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53pub struct RpcFrame(Bytes);
54
55impl RpcFrame {
56    pub fn new<T: Serialize>(id: Id, data: T) -> Result<Self> {
57        let cap = size_of::<Id>() + serialized_size(&data)? as usize;
58        let mut buf = BytesMut::with_capacity(cap);
59        buf.put_u64(id.0);
60        let mut writer = buf.writer();
61        serialize_into(&mut writer, &data)?;
62        let buf = writer.into_inner();
63        assert_eq!(cap, buf.capacity());
64        Ok(Self(buf.freeze()))
65    }
66
67    pub fn id(&self) -> Result<Id> {
68        self.0
69            .get(0..size_of::<Id>())
70            .map(|buf| {
71                Id(u64::from_be_bytes(
72                    buf.try_into().expect("infallible: hardcode slice size"),
73                ))
74            })
75            .ok_or(Error::Serialize(None))
76    }
77
78    pub fn data<'a, T: Deserialize<'a>>(&'a self) -> Result<T> {
79        Ok(deserialize(
80            self.0
81                .get(size_of::<Id>()..)
82                .ok_or(Error::Serialize(None))?,
83        )?)
84    }
85}
86
87pub type GenericStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
88pub type GenericSink<T, E> = Pin<Box<dyn Sink<T, Error = E> + Send + Sync + 'static>>;
89
90pub struct Transport {
91    input: GenericStream<Result<RpcFrame>>,
92    output: GenericSink<RpcFrame, Error>,
93}
94
95impl Transport {
96    pub fn from_streamed<T>(io: T) -> Self
97    where
98        T: AsyncRead + AsyncWrite + Send + Sync + 'static,
99    {
100        let (reader, writer) = split(io);
101        Self::from_streamed_pair(reader, writer)
102    }
103
104    pub fn from_streamed_pair<R, W>(reader: R, writer: W) -> Self
105    where
106        R: AsyncRead + Send + Sync + 'static,
107        W: AsyncWrite + Send + Sync + 'static,
108    {
109        let stream = FramedRead::new(reader, LengthDelimitedCodec::default())
110            .map(|buf| buf.map(BytesMut::freeze).map(RpcFrame).map_err(Error::from));
111        let sink = FramedWrite::new(writer, LengthDelimitedCodec::default())
112            .with(|frame: RpcFrame| ready(Ok(frame.0)));
113        Self::from_framed_pair(stream, sink)
114    }
115
116    pub fn from_framed<T>(io: T) -> Self
117    where
118        T: Stream<Item = Result<RpcFrame>> + Sink<RpcFrame, Error = Error> + Send + Sync + 'static,
119    {
120        let (sink, stream) = io.split();
121        Self::from_framed_pair(stream, sink)
122    }
123
124    pub fn from_framed_pair<T, U>(stream: T, sink: U) -> Self
125    where
126        T: Stream<Item = Result<RpcFrame>> + Send + Sync + 'static,
127        U: Sink<RpcFrame, Error = Error> + Send + Sync + 'static,
128    {
129        Self {
130            input: Box::pin(stream),
131            output: Box::pin(sink),
132        }
133    }
134
135    pub fn new_local() -> (Self, Self) {
136        let (tx1, rx1) = mpsc::unbounded::<RpcFrame>();
137        let (tx2, rx2) = mpsc::unbounded::<RpcFrame>();
138
139        let tx1 = tx1.sink_map_err(|_| Error::Io(std::io::ErrorKind::ConnectionAborted.into()));
140        let tx2 = tx2.sink_map_err(|_| Error::Io(std::io::ErrorKind::ConnectionAborted.into()));
141        let rx1 = rx1.map(Ok);
142        let rx2 = rx2.map(Ok);
143
144        let transport_l = Self::from_framed_pair(rx1, tx2);
145        let transport_r = Self::from_framed_pair(rx2, tx1);
146        (transport_l, transport_r)
147    }
148
149    pub fn split(
150        self,
151    ) -> (
152        GenericStream<Result<RpcFrame>>,
153        GenericSink<RpcFrame, Error>,
154    ) {
155        (self.input, self.output)
156    }
157}