1use std::future::Future;
2use std::io;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use base64::Engine;
8use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
9use rpc_runtime_codec_msgpack::{
10 CodecError, CodecLimits, DEFAULT_MAX_MESSAGE_SIZE, decode_envelope, encode_envelope,
11};
12use rpc_runtime_core::Envelope;
13use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
14use thiserror::Error;
15use tokio::sync::mpsc;
16
17pub type TransportFuture<'a, T> =
18 Pin<Box<dyn Future<Output = Result<T, TransportError>> + Send + 'a>>;
19
20pub type HostBridgeSendFuture =
21 Pin<Box<dyn Future<Output = Result<(), TransportError>> + Send + 'static>>;
22
23pub type AddonSendFuture = HostBridgeSendFuture;
24pub type AddonConfig = HostBridgeConfig;
25pub type AddonConnection = HostBridgeConnection;
26pub type AddonEndpoint = HostBridgeEndpoint;
27pub type AddonFrameSink = dyn HostBridgeFrameSink;
28
29#[derive(Debug, Error)]
30pub enum TransportError {
31 #[error("transport I/O error: {0}")]
32 Io(#[from] io::Error),
33 #[error("transport protocol error: {0}")]
34 Runtime(RuntimeError),
35}
36
37impl TransportError {
38 pub fn runtime(code: RuntimeErrorCode, message: impl Into<String>) -> Self {
39 Self::Runtime(RuntimeError::protocol(code, message))
40 }
41}
42
43impl From<CodecError> for TransportError {
44 fn from(value: CodecError) -> Self {
45 Self::Runtime(value.into_runtime_error())
46 }
47}
48
49pub fn encode_host_bridge_frame_base64(frame: impl AsRef<[u8]>) -> String {
50 BASE64_STANDARD.encode(frame)
51}
52
53pub trait HostBridgeFrameSink: Send + Sync {
54 fn send_frame(&self, frame: Vec<u8>) -> HostBridgeSendFuture;
55}
56
57impl<F, Fut> HostBridgeFrameSink for F
58where
59 F: Send + Sync + 'static + Fn(Vec<u8>) -> Fut,
60 Fut: Future<Output = Result<(), TransportError>> + Send + 'static,
61{
62 fn send_frame(&self, frame: Vec<u8>) -> HostBridgeSendFuture {
63 Box::pin(self(frame))
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq)]
68pub struct HostBridgeConfig {
69 pub max_frame_size: usize,
70 pub inbound_buffer: usize,
71}
72
73impl Default for HostBridgeConfig {
74 fn default() -> Self {
75 Self {
76 max_frame_size: DEFAULT_MAX_MESSAGE_SIZE,
77 inbound_buffer: 128,
78 }
79 }
80}
81
82#[derive(Clone)]
83pub struct HostBridgeEndpoint {
84 inbound: mpsc::Sender<Option<Envelope>>,
85 config: HostBridgeConfig,
86}
87
88impl HostBridgeEndpoint {
89 pub async fn receive_client_frame(
90 &self,
91 frame: impl AsRef<[u8]>,
92 ) -> Result<(), TransportError> {
93 let frame = frame.as_ref();
94 if frame.len() > self.config.max_frame_size {
95 return Err(TransportError::runtime(
96 RuntimeErrorCode::InvalidEnvelope,
97 format!(
98 "host bridge frame size {} exceeds limit {}",
99 frame.len(),
100 self.config.max_frame_size
101 ),
102 ));
103 }
104 let envelope = decode_envelope(
105 frame,
106 CodecLimits {
107 max_message_size: self.config.max_frame_size,
108 },
109 )?;
110 self.send_inbound(Some(envelope)).await
111 }
112
113 pub async fn receive_client_frame_base64(&self, frame: &str) -> Result<(), TransportError> {
114 let bytes = BASE64_STANDARD.decode(frame).map_err(|err| {
115 TransportError::runtime(
116 RuntimeErrorCode::PayloadDecodeFailed,
117 format!("failed to decode base64 host bridge frame: {err}"),
118 )
119 })?;
120 self.receive_client_frame(bytes).await
121 }
122
123 pub async fn close_client_input(&self) -> Result<(), TransportError> {
124 self.send_inbound(None).await
125 }
126
127 async fn send_inbound(&self, envelope: Option<Envelope>) -> Result<(), TransportError> {
128 self.inbound.send(envelope).await.map_err(|_| {
129 TransportError::Io(io::Error::new(
130 io::ErrorKind::BrokenPipe,
131 "host bridge connection is closed",
132 ))
133 })
134 }
135}
136
137pub struct HostBridgeConnection {
138 inner: RpcConnection,
139}
140
141impl HostBridgeConnection {
142 pub fn new<S>(config: HostBridgeConfig, sink: S) -> (Self, HostBridgeEndpoint)
143 where
144 S: HostBridgeFrameSink + 'static,
145 {
146 let buffer = config.inbound_buffer.max(1);
147 let (inbound, receiver) = mpsc::channel(buffer);
148 let sink = Arc::new(sink);
149 let connection = RpcConnection::new(
150 RpcSender::new(Arc::new(HostBridgeWriter { sink, config })),
151 RpcReceiver::new(Box::new(HostBridgeReader { receiver })),
152 );
153 (
154 Self { inner: connection },
155 HostBridgeEndpoint { inbound, config },
156 )
157 }
158
159 pub fn into_connection(self) -> RpcConnection {
160 self.inner
161 }
162}
163
164impl From<HostBridgeConnection> for RpcConnection {
165 fn from(value: HostBridgeConnection) -> Self {
166 value.into_connection()
167 }
168}
169
170struct HostBridgeWriter {
171 sink: Arc<dyn HostBridgeFrameSink>,
172 config: HostBridgeConfig,
173}
174
175impl EnvelopeWriter for HostBridgeWriter {
176 fn send_envelope<'a>(&'a self, envelope: &'a Envelope) -> TransportFuture<'a, ()> {
177 Box::pin(async move {
178 let frame = encode_envelope(envelope)?;
179 if frame.len() > self.config.max_frame_size {
180 return Err(TransportError::runtime(
181 RuntimeErrorCode::InvalidEnvelope,
182 format!(
183 "host bridge frame size {} exceeds limit {}",
184 frame.len(),
185 self.config.max_frame_size
186 ),
187 ));
188 }
189 self.sink.send_frame(frame).await
190 })
191 }
192
193 fn shutdown<'a>(&'a self) -> TransportFuture<'a, ()> {
194 Box::pin(async { Ok(()) })
195 }
196}
197
198struct HostBridgeReader {
199 receiver: mpsc::Receiver<Option<Envelope>>,
200}
201
202impl EnvelopeReader for HostBridgeReader {
203 fn recv_envelope<'a>(&'a mut self) -> TransportFuture<'a, Option<Envelope>> {
204 Box::pin(async move { Ok(self.receiver.recv().await.flatten()) })
205 }
206}
207
208pub trait EnvelopeWriter: Send + Sync {
209 fn send_envelope<'a>(&'a self, envelope: &'a Envelope) -> TransportFuture<'a, ()>;
210
211 fn shutdown<'a>(&'a self) -> TransportFuture<'a, ()>;
212}
213
214pub trait EnvelopeReader: Send {
215 fn recv_envelope<'a>(&'a mut self) -> TransportFuture<'a, Option<Envelope>>;
216}
217
218pub trait RpcListener: Send {
219 fn accept<'a>(&'a mut self) -> TransportFuture<'a, RpcConnection>;
220
221 fn set_connection_scope(&mut self, _: ConnectionScope) {}
222}
223
/// Which peers a listener is configured to serve.
///
/// NOTE(review): the variant meanings below are inferred from their names;
/// enforcement lives in the listener implementations, which are not part of
/// this file.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionScope {
    /// Presumably restricts the listener to local (loopback) peers. This is
    /// the default.
    #[default]
    LocalOnly,
    /// Presumably also admits peers on other hosts.
    RemoteAllowed,
}
235
/// Reports whether `addr` is a loopback address.
///
/// Besides `127.0.0.0/8` and `::1`, this also accepts IPv4-mapped IPv6
/// loopback addresses such as `::ffff:127.0.0.1`. That is the form in which
/// a local IPv4 peer appears on a dual-stack IPv6 listener, and
/// `Ipv6Addr::is_loopback` alone does not recognise it.
pub fn is_local_socket_addr(addr: &SocketAddr) -> bool {
    match addr.ip() {
        std::net::IpAddr::V4(v4) => v4.is_loopback(),
        std::net::IpAddr::V6(v6) => {
            v6.is_loopback() || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_loopback())
        }
    }
}
239
240#[derive(Clone)]
241pub struct RpcSender {
242 inner: Arc<dyn EnvelopeWriter>,
243}
244
245impl RpcSender {
246 pub fn new(inner: Arc<dyn EnvelopeWriter>) -> Self {
247 Self { inner }
248 }
249
250 pub async fn send_envelope(&self, envelope: &Envelope) -> Result<(), TransportError> {
251 self.inner.send_envelope(envelope).await
252 }
253
254 pub async fn shutdown(&self) -> Result<(), TransportError> {
255 self.inner.shutdown().await
256 }
257}
258
259pub struct RpcReceiver {
260 inner: Box<dyn EnvelopeReader>,
261}
262
263impl RpcReceiver {
264 pub fn new(inner: Box<dyn EnvelopeReader>) -> Self {
265 Self { inner }
266 }
267
268 pub async fn recv_envelope(&mut self) -> Result<Option<Envelope>, TransportError> {
269 self.inner.recv_envelope().await
270 }
271}
272
273pub struct RpcConnection {
274 sender: RpcSender,
275 receiver: RpcReceiver,
276}
277
278impl RpcConnection {
279 pub fn new(sender: RpcSender, receiver: RpcReceiver) -> Self {
280 Self { sender, receiver }
281 }
282
283 pub fn split(self) -> (RpcSender, RpcReceiver) {
284 (self.sender, self.receiver)
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use rpc_runtime_core::{CapabilityFlags, Hello, HelloAck, RUNTIME_PROTOCOL_VERSION, Role};

    /// Client `Hello` envelope used as the payload throughout these tests.
    fn hello_envelope() -> Envelope {
        let hello = Hello {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            role: Role::Client,
            capability_bits: CapabilityFlags::GOODBYE,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE as u64,
            options: Vec::new(),
        };
        Envelope::Hello(hello)
    }

    /// Creates a bridge whose outbound sink accepts and discards every frame.
    fn discarding_bridge(config: HostBridgeConfig) -> (HostBridgeConnection, HostBridgeEndpoint) {
        HostBridgeConnection::new(config, |_frame: Vec<u8>| async {
            Ok::<(), TransportError>(())
        })
    }

    /// Asserts that `err` is a runtime error carrying the `expected` code.
    fn assert_runtime_code(err: TransportError, expected: RuntimeErrorCode) {
        match err {
            TransportError::Runtime(error) => {
                assert_eq!(error.code, expected);
            }
            TransportError::Io(error) => panic!("expected runtime error, got I/O error: {error}"),
        }
    }

    #[tokio::test]
    async fn host_bridge_receives_bytes_frame() {
        let (connection, endpoint) = discarding_bridge(HostBridgeConfig::default());
        let (_sender, mut receiver) = connection.into_connection().split();
        let encoded = encode_envelope(&hello_envelope()).expect("encode");

        endpoint
            .receive_client_frame(&encoded)
            .await
            .expect("receive frame");

        let received = receiver.recv_envelope().await.expect("read envelope");
        assert_eq!(received, Some(hello_envelope()));
    }

    #[tokio::test]
    async fn host_bridge_receives_base64_frame() {
        let (connection, endpoint) = discarding_bridge(HostBridgeConfig::default());
        let (_sender, mut receiver) = connection.into_connection().split();
        let encoded = encode_envelope(&hello_envelope()).expect("encode");
        let text = encode_host_bridge_frame_base64(&encoded);

        endpoint
            .receive_client_frame_base64(&text)
            .await
            .expect("receive frame");

        let received = receiver.recv_envelope().await.expect("read envelope");
        assert_eq!(received, Some(hello_envelope()));
    }

    #[tokio::test]
    async fn host_bridge_sends_encoded_frame_to_sink() {
        // The sink forwards every outbound frame to a channel the test can read.
        let (frames_tx, mut frames_rx) = mpsc::unbounded_channel();
        let (connection, _endpoint) =
            HostBridgeConnection::new(HostBridgeConfig::default(), move |frame| {
                let frames_tx = frames_tx.clone();
                async move {
                    frames_tx.send(frame).map_err(|_| {
                        TransportError::Io(io::Error::new(
                            io::ErrorKind::BrokenPipe,
                            "test frame sink closed",
                        ))
                    })
                }
            });
        let (sender, _receiver) = connection.into_connection().split();
        let ack = Envelope::HelloAck(HelloAck {
            protocol_version: RUNTIME_PROTOCOL_VERSION,
            accepted_capability_bits: CapabilityFlags::GOODBYE,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE as u64,
            options: Vec::new(),
        });

        sender.send_envelope(&ack).await.expect("send envelope");
        let frame = frames_rx.recv().await.expect("sink frame");

        let decoded = decode_envelope(&frame, CodecLimits::default()).expect("decode");
        assert_eq!(decoded, ack);
    }

    #[tokio::test]
    async fn host_bridge_close_client_input_returns_eof() {
        let (connection, endpoint) = discarding_bridge(HostBridgeConfig::default());
        let (_sender, mut receiver) = connection.into_connection().split();

        endpoint
            .close_client_input()
            .await
            .expect("close client input");

        let received = receiver.recv_envelope().await.expect("read eof");
        assert_eq!(received, None);
    }

    #[tokio::test]
    async fn host_bridge_rejects_oversized_inbound_frame() {
        let tiny = HostBridgeConfig {
            max_frame_size: 1,
            inbound_buffer: 1,
        };
        let (_connection, endpoint) = discarding_bridge(tiny);
        let encoded = encode_envelope(&hello_envelope()).expect("encode");

        let err = endpoint
            .receive_client_frame(&encoded)
            .await
            .expect_err("oversized frame must fail");

        assert_runtime_code(err, RuntimeErrorCode::InvalidEnvelope);
    }

    #[tokio::test]
    async fn host_bridge_rejects_invalid_base64_frame() {
        let (_connection, endpoint) = discarding_bridge(HostBridgeConfig::default());

        let err = endpoint
            .receive_client_frame_base64("not base64!")
            .await
            .expect_err("invalid base64 must fail");

        assert_runtime_code(err, RuntimeErrorCode::PayloadDecodeFailed);
    }
}