Skip to main content

rpc_runtime_transport/
lib.rs

1use std::future::Future;
2use std::io;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use base64::Engine;
8use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
9use rpc_runtime_codec_msgpack::{
10    CodecError, CodecLimits, DEFAULT_MAX_MESSAGE_SIZE, decode_envelope, encode_envelope,
11};
12use rpc_runtime_core::Envelope;
13use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
14use thiserror::Error;
15use tokio::sync::mpsc;
16
17pub type TransportFuture<'a, T> =
18    Pin<Box<dyn Future<Output = Result<T, TransportError>> + Send + 'a>>;
19
20pub type HostBridgeSendFuture =
21    Pin<Box<dyn Future<Output = Result<(), TransportError>> + Send + 'static>>;
22
23pub type AddonSendFuture = HostBridgeSendFuture;
24pub type AddonConfig = HostBridgeConfig;
25pub type AddonConnection = HostBridgeConnection;
26pub type AddonEndpoint = HostBridgeEndpoint;
27pub type AddonFrameSink = dyn HostBridgeFrameSink;
28
/// Errors produced by RPC transports.
///
/// Failures fall into two groups: I/O failures of the underlying channel or sink, and
/// protocol-level failures (oversized frames, undecodable payloads, codec errors) carried
/// as a [`RuntimeError`].
#[derive(Debug, Error)]
pub enum TransportError {
    /// An I/O failure. The host bridge also uses this, with `BrokenPipe`, when its
    /// inbound channel has been closed.
    #[error("transport I/O error: {0}")]
    Io(#[from] io::Error),
    /// A protocol-level failure such as an oversized frame or a codec error.
    #[error("transport protocol error: {0}")]
    Runtime(RuntimeError),
}
36
37impl TransportError {
38    pub fn runtime(code: RuntimeErrorCode, message: impl Into<String>) -> Self {
39        Self::Runtime(RuntimeError::protocol(code, message))
40    }
41}
42
43impl From<CodecError> for TransportError {
44    fn from(value: CodecError) -> Self {
45        Self::Runtime(value.into_runtime_error())
46    }
47}
48
49pub fn encode_host_bridge_frame_base64(frame: impl AsRef<[u8]>) -> String {
50    BASE64_STANDARD.encode(frame)
51}
52
/// Destination for outbound host bridge frames.
///
/// Any `Send + Sync + 'static` closure `Fn(Vec<u8>) -> Fut`, where `Fut` is a `Send + 'static`
/// future resolving to `Result<(), TransportError>`, implements this trait.
pub trait HostBridgeFrameSink: Send + Sync {
    /// Hands one frame, as produced by `encode_envelope`, to the host.
    ///
    /// When called by the host bridge writer, the frame has already been checked against
    /// `max_frame_size`. The future's result is returned to the `send_envelope` caller.
    fn send_frame(&self, frame: Vec<u8>) -> HostBridgeSendFuture;
}
56
57impl<F, Fut> HostBridgeFrameSink for F
58where
59    F: Send + Sync + 'static + Fn(Vec<u8>) -> Fut,
60    Fut: Future<Output = Result<(), TransportError>> + Send + 'static,
61{
62    fn send_frame(&self, frame: Vec<u8>) -> HostBridgeSendFuture {
63        Box::pin(self(frame))
64    }
65}
66
/// Tuning knobs for a host bridge connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HostBridgeConfig {
    /// Largest accepted frame in bytes, enforced on both inbound and outbound frames.
    pub max_frame_size: usize,
    /// Capacity of the inbound envelope queue; a value of `0` is treated as `1`.
    pub inbound_buffer: usize,
}
72
73impl Default for HostBridgeConfig {
74    fn default() -> Self {
75        Self {
76            max_frame_size: DEFAULT_MAX_MESSAGE_SIZE,
77            inbound_buffer: 128,
78        }
79    }
80}
81
/// Host-facing handle used to push client frames into a [`HostBridgeConnection`].
///
/// Clones share the same inbound queue, so every clone feeds the same connection.
#[derive(Clone)]
pub struct HostBridgeEndpoint {
    /// Queue into the connection's reader; a `None` item marks end of client input.
    inbound: mpsc::Sender<Option<Envelope>>,
    /// Copy of the connection's configuration, used for inbound size limits.
    config: HostBridgeConfig,
}
87
88impl HostBridgeEndpoint {
89    pub async fn receive_client_frame(
90        &self,
91        frame: impl AsRef<[u8]>,
92    ) -> Result<(), TransportError> {
93        let frame = frame.as_ref();
94        if frame.len() > self.config.max_frame_size {
95            return Err(TransportError::runtime(
96                RuntimeErrorCode::InvalidEnvelope,
97                format!(
98                    "host bridge frame size {} exceeds limit {}",
99                    frame.len(),
100                    self.config.max_frame_size
101                ),
102            ));
103        }
104        let envelope = decode_envelope(
105            frame,
106            CodecLimits {
107                max_message_size: self.config.max_frame_size,
108            },
109        )?;
110        self.send_inbound(Some(envelope)).await
111    }
112
113    pub async fn receive_client_frame_base64(&self, frame: &str) -> Result<(), TransportError> {
114        let bytes = BASE64_STANDARD.decode(frame).map_err(|err| {
115            TransportError::runtime(
116                RuntimeErrorCode::PayloadDecodeFailed,
117                format!("failed to decode base64 host bridge frame: {err}"),
118            )
119        })?;
120        self.receive_client_frame(bytes).await
121    }
122
123    pub async fn close_client_input(&self) -> Result<(), TransportError> {
124        self.send_inbound(None).await
125    }
126
127    async fn send_inbound(&self, envelope: Option<Envelope>) -> Result<(), TransportError> {
128        self.inbound.send(envelope).await.map_err(|_| {
129            TransportError::Io(io::Error::new(
130                io::ErrorKind::BrokenPipe,
131                "host bridge connection is closed",
132            ))
133        })
134    }
135}
136
/// An [`RpcConnection`] whose outbound frames go to a [`HostBridgeFrameSink`] and whose
/// inbound frames are pushed in through a paired [`HostBridgeEndpoint`].
pub struct HostBridgeConnection {
    inner: RpcConnection,
}
140
141impl HostBridgeConnection {
142    pub fn new<S>(config: HostBridgeConfig, sink: S) -> (Self, HostBridgeEndpoint)
143    where
144        S: HostBridgeFrameSink + 'static,
145    {
146        let buffer = config.inbound_buffer.max(1);
147        let (inbound, receiver) = mpsc::channel(buffer);
148        let sink = Arc::new(sink);
149        let connection = RpcConnection::new(
150            RpcSender::new(Arc::new(HostBridgeWriter { sink, config })),
151            RpcReceiver::new(Box::new(HostBridgeReader { receiver })),
152        );
153        (
154            Self { inner: connection },
155            HostBridgeEndpoint { inbound, config },
156        )
157    }
158
159    pub fn into_connection(self) -> RpcConnection {
160        self.inner
161    }
162}
163
164impl From<HostBridgeConnection> for RpcConnection {
165    fn from(value: HostBridgeConnection) -> Self {
166        value.into_connection()
167    }
168}
169
/// Outbound half of a host bridge: encodes envelopes and hands the frames to the sink.
struct HostBridgeWriter {
    /// Host-provided destination for encoded frames.
    sink: Arc<dyn HostBridgeFrameSink>,
    /// Connection configuration; `max_frame_size` bounds outbound frames.
    config: HostBridgeConfig,
}
174
175impl EnvelopeWriter for HostBridgeWriter {
176    fn send_envelope<'a>(&'a self, envelope: &'a Envelope) -> TransportFuture<'a, ()> {
177        Box::pin(async move {
178            let frame = encode_envelope(envelope)?;
179            if frame.len() > self.config.max_frame_size {
180                return Err(TransportError::runtime(
181                    RuntimeErrorCode::InvalidEnvelope,
182                    format!(
183                        "host bridge frame size {} exceeds limit {}",
184                        frame.len(),
185                        self.config.max_frame_size
186                    ),
187                ));
188            }
189            self.sink.send_frame(frame).await
190        })
191    }
192
193    fn shutdown<'a>(&'a self) -> TransportFuture<'a, ()> {
194        Box::pin(async { Ok(()) })
195    }
196}
197
/// Inbound half of a host bridge: yields envelopes queued through a [`HostBridgeEndpoint`].
struct HostBridgeReader {
    /// Receives decoded envelopes; an inner `None` marks end of client input.
    receiver: mpsc::Receiver<Option<Envelope>>,
}
201
202impl EnvelopeReader for HostBridgeReader {
203    fn recv_envelope<'a>(&'a mut self) -> TransportFuture<'a, Option<Envelope>> {
204        Box::pin(async move { Ok(self.receiver.recv().await.flatten()) })
205    }
206}
207
/// Outbound half of a transport: sends whole envelopes to the peer.
pub trait EnvelopeWriter: Send + Sync {
    /// Sends one envelope; the future borrows both `self` and `envelope`.
    fn send_envelope<'a>(&'a self, envelope: &'a Envelope) -> TransportFuture<'a, ()>;

    /// Shuts down the outbound direction. The host bridge writer in this module
    /// implements it as a no-op.
    fn shutdown<'a>(&'a self) -> TransportFuture<'a, ()>;
}
213
/// Inbound half of a transport: yields envelopes received from the peer.
pub trait EnvelopeReader: Send {
    /// Receives the next envelope; `Ok(None)` signals end of input.
    fn recv_envelope<'a>(&'a mut self) -> TransportFuture<'a, Option<Envelope>>;
}
217
/// Source of incoming [`RpcConnection`]s.
pub trait RpcListener: Send {
    /// Accepts the next connection.
    fn accept<'a>(&'a mut self) -> TransportFuture<'a, RpcConnection>;

    /// Configures the listener's [`ConnectionScope`]. The default implementation ignores
    /// the value. NOTE(review): implementors that honor it presumably filter peers with
    /// [`is_local_socket_addr`]; confirm.
    fn set_connection_scope(&mut self, _: ConnectionScope) {}
}
223
/// Which peers a listener is permitted to accept connections from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ConnectionScope {
    /// Local connections only. This is the default scope.
    #[default]
    LocalOnly,
    /// Remote connections are permitted as well.
    RemoteAllowed,
}
235
/// Returns `true` when `addr` is a loopback address.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are canonicalized to IPv4 first. This way
/// a local IPv4 peer accepted on a dual-stack IPv6 socket (reported as
/// `[::ffff:127.0.0.1]`) is still recognized as local.
pub fn is_local_socket_addr(addr: &SocketAddr) -> bool {
    // `Ipv6Addr::is_loopback` only matches `::1`, so unmap IPv4-mapped addresses first.
    addr.ip().to_canonical().is_loopback()
}
239
/// Cloneable handle for sending envelopes over a connection.
///
/// Clones share the same underlying [`EnvelopeWriter`].
#[derive(Clone)]
pub struct RpcSender {
    inner: Arc<dyn EnvelopeWriter>,
}
244
245impl RpcSender {
246    pub fn new(inner: Arc<dyn EnvelopeWriter>) -> Self {
247        Self { inner }
248    }
249
250    pub async fn send_envelope(&self, envelope: &Envelope) -> Result<(), TransportError> {
251        self.inner.send_envelope(envelope).await
252    }
253
254    pub async fn shutdown(&self) -> Result<(), TransportError> {
255        self.inner.shutdown().await
256    }
257}
258
/// Owning handle for receiving envelopes from a connection.
pub struct RpcReceiver {
    inner: Box<dyn EnvelopeReader>,
}
262
263impl RpcReceiver {
264    pub fn new(inner: Box<dyn EnvelopeReader>) -> Self {
265        Self { inner }
266    }
267
268    pub async fn recv_envelope(&mut self) -> Result<Option<Envelope>, TransportError> {
269        self.inner.recv_envelope().await
270    }
271}
272
/// A bidirectional envelope connection made of a sender half and a receiver half.
pub struct RpcConnection {
    sender: RpcSender,
    receiver: RpcReceiver,
}
277
278impl RpcConnection {
279    pub fn new(sender: RpcSender, receiver: RpcReceiver) -> Self {
280        Self { sender, receiver }
281    }
282
283    pub fn split(self) -> (RpcSender, RpcReceiver) {
284        (self.sender, self.receiver)
285    }
286}
287
#[cfg(test)]
mod tests {
    // Host bridge round-trip tests: inbound frames (raw and base64), outbound frames to the
    // sink, end-of-input signalling, and size / encoding rejections.
    use super::*;
    use rpc_runtime_core::{CapabilityFlags, Hello, HelloAck, RUNTIME_PROTOCOL_VERSION, Role};

    fn hello_envelope() -> Envelope {
        Envelope::Hello(Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::GOODBYE,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE as u64,
            options: Vec::new(),
        })
    }

    #[tokio::test]
    async fn host_bridge_receives_bytes_frame() {
        let (connection, endpoint) =
            HostBridgeConnection::new(HostBridgeConfig::default(), |_frame| async { Ok(()) });
        let (_sender, mut receiver) = connection.into_connection().split();
        let frame = encode_envelope(&hello_envelope()).expect("encode");

        endpoint
            .receive_client_frame(&frame)
            .await
            .expect("receive frame");

        assert_eq!(
            receiver.recv_envelope().await.expect("read envelope"),
            Some(hello_envelope())
        );
    }

    #[tokio::test]
    async fn host_bridge_receives_base64_frame() {
        let (connection, endpoint) =
            HostBridgeConnection::new(HostBridgeConfig::default(), |_frame| async { Ok(()) });
        let (_sender, mut receiver) = connection.into_connection().split();
        let frame = encode_envelope(&hello_envelope()).expect("encode");
        let encoded = encode_host_bridge_frame_base64(&frame);

        endpoint
            .receive_client_frame_base64(&encoded)
            .await
            .expect("receive frame");

        assert_eq!(
            receiver.recv_envelope().await.expect("read envelope"),
            Some(hello_envelope())
        );
    }

    #[tokio::test]
    async fn host_bridge_sends_encoded_frame_to_sink() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let (connection, _endpoint) =
            HostBridgeConnection::new(HostBridgeConfig::default(), move |frame| {
                let tx = tx.clone();
                async move {
                    tx.send(frame).map_err(|_| {
                        TransportError::Io(io::Error::new(
                            io::ErrorKind::BrokenPipe,
                            "test frame sink closed",
                        ))
                    })
                }
            });
        let (sender, _receiver) = connection.into_connection().split();
        let envelope = Envelope::HelloAck(HelloAck {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            accepted_capability_bits: CapabilityFlags::GOODBYE,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE as u64,
            options: Vec::new(),
        });

        sender
            .send_envelope(&envelope)
            .await
            .expect("send envelope");
        let frame = rx.recv().await.expect("sink frame");

        assert_eq!(
            decode_envelope(&frame, CodecLimits::default()).expect("decode"),
            envelope
        );
    }

    #[tokio::test]
    async fn host_bridge_close_client_input_returns_eof() {
        let (connection, endpoint) =
            HostBridgeConnection::new(HostBridgeConfig::default(), |_frame| async { Ok(()) });
        let (_sender, mut receiver) = connection.into_connection().split();

        endpoint
            .close_client_input()
            .await
            .expect("close client input");

        assert_eq!(receiver.recv_envelope().await.expect("read eof"), None);
    }

    #[tokio::test]
    async fn host_bridge_rejects_oversized_inbound_frame() {
        let (_connection, endpoint) = HostBridgeConnection::new(
            HostBridgeConfig {
                max_frame_size: 1,
                inbound_buffer: 1,
            },
            |_frame| async { Ok(()) },
        );
        let frame = encode_envelope(&hello_envelope()).expect("encode");

        let err = endpoint
            .receive_client_frame(&frame)
            .await
            .expect_err("oversized frame must fail");

        match err {
            TransportError::Runtime(error) => {
                assert_eq!(error.code, RuntimeErrorCode::InvalidEnvelope);
            }
            TransportError::Io(error) => panic!("expected runtime error, got I/O error: {error}"),
        }
    }

    // The writer enforces `max_frame_size` on encoded outbound frames; the oversized frame
    // must be rejected and must never be handed to the sink.
    #[tokio::test]
    async fn host_bridge_rejects_oversized_outbound_frame() {
        let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
        let (connection, _endpoint) = HostBridgeConnection::new(
            HostBridgeConfig {
                max_frame_size: 1,
                inbound_buffer: 1,
            },
            move |frame| {
                let tx = tx.clone();
                async move {
                    tx.send(frame).map_err(|_| {
                        TransportError::Io(io::Error::new(
                            io::ErrorKind::BrokenPipe,
                            "test frame sink closed",
                        ))
                    })
                }
            },
        );
        let (sender, _receiver) = connection.into_connection().split();

        let err = sender
            .send_envelope(&hello_envelope())
            .await
            .expect_err("oversized outbound frame must fail");

        match err {
            TransportError::Runtime(error) => {
                assert_eq!(error.code, RuntimeErrorCode::InvalidEnvelope);
            }
            TransportError::Io(error) => panic!("expected runtime error, got I/O error: {error}"),
        }
        assert!(
            rx.try_recv().is_err(),
            "oversized frame must not reach the sink"
        );
    }

    #[tokio::test]
    async fn host_bridge_rejects_invalid_base64_frame() {
        let (_connection, endpoint) =
            HostBridgeConnection::new(HostBridgeConfig::default(), |_frame| async { Ok(()) });

        let err = endpoint
            .receive_client_frame_base64("not base64!")
            .await
            .expect_err("invalid base64 must fail");

        match err {
            TransportError::Runtime(error) => {
                assert_eq!(error.code, RuntimeErrorCode::PayloadDecodeFailed);
            }
            TransportError::Io(error) => panic!("expected runtime error, got I/O error: {error}"),
        }
    }
}