1use super::{Location, Topology};
8use crate::identifiers::RoleName;
9use crate::mutex_lock;
10#[cfg(not(target_arch = "wasm32"))]
11use crate::runtime::spawn::spawn;
12use crate::runtime::sync::{mpsc, Mutex};
13use async_trait::async_trait;
14use cfg_if::cfg_if;
15#[cfg(target_arch = "wasm32")]
16use futures::{SinkExt, StreamExt};
17use std::collections::BTreeMap;
18use std::sync::Arc;
19#[cfg(not(target_arch = "wasm32"))]
20use std::sync::{Mutex as StdMutex, OnceLock};
21use thiserror::Error;
22
23#[cfg(not(target_arch = "wasm32"))]
24use tokio::io::{AsyncReadExt, AsyncWriteExt};
25#[cfg(not(target_arch = "wasm32"))]
26use tokio::net::{TcpListener, TcpStream};
27#[cfg(not(target_arch = "wasm32"))]
28use tokio::time::{sleep, Duration};
29
30#[derive(Debug, Error)]
32pub enum TransportError {
33 #[error("connection failed: {0}")]
34 ConnectionFailed(String),
35
36 #[error("send failed: {0}")]
37 SendFailed(String),
38
39 #[error("receive failed: {0}")]
40 ReceiveFailed(String),
41
42 #[error("timeout")]
43 Timeout,
44
45 #[error("channel closed")]
46 ChannelClosed,
47
48 #[error("unknown role: {0}")]
49 UnknownRole(RoleName),
50
51 #[error("transport not ready")]
52 NotReady,
53
54 #[error("IO error: {0}")]
55 IoError(#[from] std::io::Error),
56}
57
58pub type TransportResult<T> = Result<T, TransportError>;
60
/// A message type that can be converted to and from a raw byte buffer.
pub trait TransportMessage: Send + Sync + 'static {
    /// Encodes this message into an owned byte vector.
    fn to_bytes(&self) -> Vec<u8>;

    /// Decodes a message from `bytes`.
    ///
    /// # Errors
    ///
    /// Returns a textual description when `bytes` is not a valid encoding.
    fn from_bytes(bytes: &[u8]) -> Result<Self, String>
    where
        Self: Sized;
}
71
72#[derive(Debug, Clone)]
74pub struct ByteMessage(pub Vec<u8>);
75
76impl TransportMessage for ByteMessage {
77 fn to_bytes(&self) -> Vec<u8> {
78 self.0.clone()
79 }
80
81 fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
82 Ok(ByteMessage(bytes.to_vec()))
83 }
84}
85
86#[async_trait]
88pub trait Transport: Send + Sync + 'static {
89 async fn send(&self, to_role: &RoleName, message: Vec<u8>) -> TransportResult<()>;
91
92 async fn recv(&self, from_role: &RoleName) -> TransportResult<Vec<u8>>;
94
95 fn is_connected(&self, role: &RoleName) -> bool;
97
98 async fn close(&self) -> TransportResult<()>;
100}
101
102pub struct InMemoryChannelTransport {
107 role: RoleName,
109 senders: Arc<Mutex<BTreeMap<RoleName, mpsc::Sender<Vec<u8>>>>>,
111 receivers: Arc<Mutex<BTreeMap<RoleName, mpsc::Receiver<Vec<u8>>>>>,
113}
114
115impl InMemoryChannelTransport {
116 pub fn new(role: RoleName) -> Self {
118 Self {
119 role,
120 senders: Arc::new(Mutex::new(BTreeMap::new())),
121 receivers: Arc::new(Mutex::new(BTreeMap::new())),
122 }
123 }
124
125 pub async fn connect(&self, other: &InMemoryChannelTransport) {
127 let (tx1, rx1) = mpsc::channel(32);
128 let (tx2, rx2) = mpsc::channel(32);
129
130 mutex_lock!(self.senders).insert(other.role.clone(), tx1);
132 mutex_lock!(other.receivers).insert(self.role.clone(), rx1);
133
134 mutex_lock!(other.senders).insert(self.role.clone(), tx2);
136 mutex_lock!(self.receivers).insert(other.role.clone(), rx2);
137 }
138
139 pub fn role(&self) -> &RoleName {
141 &self.role
142 }
143}
144
145#[async_trait]
146impl Transport for InMemoryChannelTransport {
147 async fn send(&self, to_role: &RoleName, message: Vec<u8>) -> TransportResult<()> {
148 cfg_if! {
149 if #[cfg(target_arch = "wasm32")] {
150 let sender = {
152 let senders = mutex_lock!(self.senders);
153 senders
154 .get(to_role)
155 .cloned()
156 .ok_or_else(|| TransportError::UnknownRole(to_role.clone()))?
157 };
158
159 let mut sender = sender;
160 sender
161 .send(message)
162 .await
163 .map_err(|_| TransportError::ChannelClosed)
164 } else {
165 let senders = mutex_lock!(self.senders);
166 let sender = senders
167 .get(to_role)
168 .ok_or_else(|| TransportError::UnknownRole(to_role.clone()))?;
169
170 sender
171 .send(message)
172 .await
173 .map_err(|_| TransportError::ChannelClosed)
174 }
175 }
176 }
177
178 async fn recv(&self, from_role: &RoleName) -> TransportResult<Vec<u8>> {
179 cfg_if! {
180 if #[cfg(target_arch = "wasm32")] {
181 let mut receiver = {
183 let mut receivers = mutex_lock!(self.receivers);
184 receivers
185 .remove(from_role)
186 .ok_or_else(|| TransportError::UnknownRole(from_role.clone()))?
187 };
188
189 let result = receiver.next().await;
190
191 {
192 let mut receivers = mutex_lock!(self.receivers);
193 receivers.insert(from_role.clone(), receiver);
194 }
195
196 result.ok_or(TransportError::ChannelClosed)
197 } else {
198 let mut receivers = mutex_lock!(self.receivers);
199 let receiver = receivers
200 .get_mut(from_role)
201 .ok_or_else(|| TransportError::UnknownRole(from_role.clone()))?;
202 receiver.recv().await.ok_or(TransportError::ChannelClosed)
203 }
204 }
205 }
206
207 fn is_connected(&self, _role: &RoleName) -> bool {
208 true
211 }
212
213 async fn close(&self) -> TransportResult<()> {
214 mutex_lock!(self.senders).clear();
215 mutex_lock!(self.receivers).clear();
216 Ok(())
217 }
218}
219
/// Lifecycle of the inbound TCP listener owned by a `TcpRoleState`.
#[cfg(not(target_arch = "wasm32"))]
enum TcpListenerState {
    /// No start attempt has been made yet.
    NotStarted,
    /// Start-up finished (the listener is bound, or there was nothing to bind).
    Started,
    /// Binding failed; the payload is the recorded error message.
    Failed(String),
}
226
227#[cfg(not(target_arch = "wasm32"))]
228struct TcpRoleState {
229 role: RoleName,
230 self_endpoint: Option<crate::identifiers::Endpoint>,
231 inbound_senders: BTreeMap<RoleName, mpsc::Sender<Vec<u8>>>,
232 inbound_receivers: Arc<Mutex<BTreeMap<RoleName, mpsc::Receiver<Vec<u8>>>>>,
233 listener_state: Arc<Mutex<TcpListenerState>>,
234}
235
236#[cfg(not(target_arch = "wasm32"))]
237impl TcpRoleState {
238 fn new(
239 role: RoleName,
240 self_endpoint: Option<crate::identifiers::Endpoint>,
241 peer_roles: impl IntoIterator<Item = RoleName>,
242 ) -> Self {
243 let mut inbound_senders = BTreeMap::new();
244 let mut inbound_receivers = BTreeMap::new();
245 for peer in peer_roles {
246 let (tx, rx) = mpsc::channel(32);
247 inbound_senders.insert(peer.clone(), tx);
248 inbound_receivers.insert(peer, rx);
249 }
250 Self {
251 role,
252 self_endpoint,
253 inbound_senders,
254 inbound_receivers: Arc::new(Mutex::new(inbound_receivers)),
255 listener_state: Arc::new(Mutex::new(TcpListenerState::NotStarted)),
256 }
257 }
258
259 async fn ensure_started(self: &Arc<Self>) -> TransportResult<()> {
260 let mut state = mutex_lock!(self.listener_state);
261 match &*state {
262 TcpListenerState::Started => return Ok(()),
263 TcpListenerState::Failed(message) => {
264 return Err(TransportError::ConnectionFailed(message.clone()));
265 }
266 TcpListenerState::NotStarted => {}
267 }
268
269 let Some(endpoint) = self.self_endpoint.clone() else {
270 *state = TcpListenerState::Started;
271 return Ok(());
272 };
273
274 let listener = TcpListener::bind(endpoint.as_str()).await.map_err(|err| {
275 let message = format!(
276 "failed to bind {} for role {}: {}",
277 endpoint, self.role, err
278 );
279 *state = TcpListenerState::Failed(message.clone());
280 TransportError::ConnectionFailed(message)
281 })?;
282 let role_state = Arc::clone(self);
283 spawn(async move {
284 role_state.accept_loop(listener).await;
285 });
286 *state = TcpListenerState::Started;
287 Ok(())
288 }
289
290 async fn accept_loop(self: Arc<Self>, listener: TcpListener) {
291 loop {
292 let Ok((socket, _)) = listener.accept().await else {
293 break;
294 };
295 let role_state = Arc::clone(&self);
296 spawn(async move {
297 let _ = role_state.handle_socket(socket).await;
298 });
299 }
300 }
301
302 async fn handle_socket(&self, mut socket: TcpStream) -> TransportResult<()> {
303 let role_len = socket.read_u32().await? as usize;
304 let mut role_buf = vec![0_u8; role_len];
305 socket.read_exact(&mut role_buf).await?;
306 let from_role = String::from_utf8(role_buf).map_err(|err| {
307 TransportError::ReceiveFailed(format!("invalid sender header: {err}"))
308 })?;
309 let payload_len = socket.read_u32().await? as usize;
310 let mut payload = vec![0_u8; payload_len];
311 socket.read_exact(&mut payload).await?;
312 let sender_role = RoleName::new(from_role.clone()).map_err(|err| {
313 TransportError::ReceiveFailed(format!("invalid sender role `{from_role}`: {err}"))
314 })?;
315 let sender = self
316 .inbound_senders
317 .get(&sender_role)
318 .cloned()
319 .ok_or_else(|| {
320 TransportError::ReceiveFailed(format!(
321 "sender role `{sender_role}` is not configured for {}",
322 self.role
323 ))
324 })?;
325 sender
326 .send(payload)
327 .await
328 .map_err(|_| TransportError::ChannelClosed)
329 }
330
331 async fn recv_from(&self, from_role: &RoleName) -> TransportResult<Vec<u8>> {
332 let mut receivers = mutex_lock!(self.inbound_receivers);
333 let receiver = receivers
334 .get_mut(from_role)
335 .ok_or_else(|| TransportError::UnknownRole(from_role.clone()))?;
336 receiver.recv().await.ok_or(TransportError::ChannelClosed)
337 }
338}
339
340#[cfg(not(target_arch = "wasm32"))]
341type SharedTcpRegistry = BTreeMap<String, Arc<TcpRoleState>>;
342
343#[cfg(not(target_arch = "wasm32"))]
344fn shared_tcp_registry() -> &'static StdMutex<SharedTcpRegistry> {
345 static REGISTRY: OnceLock<StdMutex<SharedTcpRegistry>> = OnceLock::new();
346 REGISTRY.get_or_init(|| StdMutex::new(BTreeMap::new()))
347}
348
349#[cfg(not(target_arch = "wasm32"))]
350fn tcp_role_registry_key(topology_signature: &str, role: &RoleName) -> String {
351 format!("{topology_signature}|role:{role}")
352}
353
354#[cfg(not(target_arch = "wasm32"))]
355fn shared_tcp_role_state(
356 topology: &Topology,
357 topology_signature: &str,
358 role: &RoleName,
359) -> TransportResult<Arc<TcpRoleState>> {
360 let key = tcp_role_registry_key(topology_signature, role);
361 let mut registry = shared_tcp_registry()
362 .lock()
363 .unwrap_or_else(|poisoned| poisoned.into_inner());
364 if let Some(existing) = registry.get(&key) {
365 return Ok(Arc::clone(existing));
366 }
367
368 let self_endpoint = match topology.get_location(role) {
369 Ok(Location::Remote(endpoint)) => Some(endpoint),
370 Ok(Location::Local | Location::Colocated(_)) => None,
371 Err(_) => return Err(TransportError::UnknownRole(role.clone())),
372 };
373 let peer_roles = topology
374 .locations
375 .keys()
376 .filter(|peer| *peer != role)
377 .cloned();
378 let state = Arc::new(TcpRoleState::new(role.clone(), self_endpoint, peer_roles));
379 registry.insert(key, Arc::clone(&state));
380 Ok(state)
381}
382
383#[cfg(not(target_arch = "wasm32"))]
384async fn connect_with_retry(endpoint: &crate::identifiers::Endpoint) -> TransportResult<TcpStream> {
385 let mut attempts = 0_u8;
386 loop {
387 match TcpStream::connect(endpoint.as_str()).await {
388 Ok(stream) => return Ok(stream),
389 Err(err) if attempts < 10 => {
390 attempts = attempts.saturating_add(1);
391 if err.kind() != std::io::ErrorKind::ConnectionRefused {
392 return Err(TransportError::ConnectionFailed(err.to_string()));
393 }
394 sleep(Duration::from_millis(10)).await;
395 }
396 Err(err) => return Err(TransportError::ConnectionFailed(err.to_string())),
397 }
398 }
399}
400
401#[cfg(not(target_arch = "wasm32"))]
402struct TcpPeerTransport {
403 state: Arc<TcpRoleState>,
404 peer_role: RoleName,
405 peer_endpoint: Option<crate::identifiers::Endpoint>,
406}
407
408#[cfg(not(target_arch = "wasm32"))]
409#[async_trait]
410impl Transport for TcpPeerTransport {
411 async fn send(&self, to_role: &RoleName, message: Vec<u8>) -> TransportResult<()> {
412 if to_role != &self.peer_role {
413 return Err(TransportError::UnknownRole(to_role.clone()));
414 }
415 let endpoint = self.peer_endpoint.clone().ok_or_else(|| {
416 TransportError::ConnectionFailed(format!(
417 "role {} has no remote endpoint configured for peer {}",
418 self.state.role, self.peer_role
419 ))
420 })?;
421 let mut stream = connect_with_retry(&endpoint).await?;
422 let role_bytes = self.state.role.to_string().into_bytes();
423 stream.write_u32(role_bytes.len() as u32).await?;
424 stream.write_all(&role_bytes).await?;
425 stream.write_u32(message.len() as u32).await?;
426 stream.write_all(&message).await?;
427 stream.shutdown().await?;
428 Ok(())
429 }
430
431 async fn recv(&self, from_role: &RoleName) -> TransportResult<Vec<u8>> {
432 if from_role != &self.peer_role {
433 return Err(TransportError::UnknownRole(from_role.clone()));
434 }
435 self.state.recv_from(from_role).await
436 }
437
438 fn is_connected(&self, role: &RoleName) -> bool {
439 role == &self.peer_role
440 }
441
442 async fn close(&self) -> TransportResult<()> {
443 Ok(())
444 }
445}
446
447#[cfg(not(target_arch = "wasm32"))]
448struct TcpRoleTransport {
449 state: Arc<TcpRoleState>,
450 peer_endpoints: BTreeMap<RoleName, Option<crate::identifiers::Endpoint>>,
451}
452
453#[cfg(not(target_arch = "wasm32"))]
454#[async_trait]
455impl Transport for TcpRoleTransport {
456 async fn send(&self, to_role: &RoleName, message: Vec<u8>) -> TransportResult<()> {
457 self.state.ensure_started().await?;
458 let endpoint = self
459 .peer_endpoints
460 .get(to_role)
461 .cloned()
462 .flatten()
463 .ok_or_else(|| {
464 TransportError::ConnectionFailed(format!(
465 "role {} has no remote endpoint configured for peer {}",
466 self.state.role, to_role
467 ))
468 })?;
469 let mut stream = connect_with_retry(&endpoint).await?;
470 let role_bytes = self.state.role.to_string().into_bytes();
471 stream.write_u32(role_bytes.len() as u32).await?;
472 stream.write_all(&role_bytes).await?;
473 stream.write_u32(message.len() as u32).await?;
474 stream.write_all(&message).await?;
475 stream.shutdown().await?;
476 Ok(())
477 }
478
479 async fn recv(&self, from_role: &RoleName) -> TransportResult<Vec<u8>> {
480 self.state.ensure_started().await?;
481 self.state.recv_from(from_role).await
482 }
483
484 fn is_connected(&self, role: &RoleName) -> bool {
485 self.peer_endpoints.contains_key(role)
486 }
487
488 async fn close(&self) -> TransportResult<()> {
489 Ok(())
490 }
491}
492
493#[cfg(not(target_arch = "wasm32"))]
494pub(crate) async fn create_peer_transport(
495 topology: &Topology,
496 topology_signature: &str,
497 role: &RoleName,
498 peer: &RoleName,
499) -> TransportResult<Box<dyn Transport>> {
500 topology
501 .region_for_role(role)
502 .map_err(TransportError::ConnectionFailed)?;
503 topology
504 .region_for_role(peer)
505 .map_err(TransportError::ConnectionFailed)?;
506 let state = shared_tcp_role_state(topology, topology_signature, role)?;
507 state.ensure_started().await?;
508 let peer_endpoint = match topology.get_location(peer) {
509 Ok(Location::Remote(endpoint)) => Some(endpoint),
510 Ok(Location::Local | Location::Colocated(_)) => None,
511 Err(_) => return Err(TransportError::UnknownRole(peer.clone())),
512 };
513 Ok(Box::new(TcpPeerTransport {
514 state,
515 peer_role: peer.clone(),
516 peer_endpoint,
517 }))
518}
519
520pub struct TransportFactory;
522
523impl TransportFactory {
524 pub fn create(topology: &Topology, role: &RoleName) -> TransportResult<Box<dyn Transport>> {
526 let has_remote_participants = topology
527 .locations
528 .values()
529 .any(|location| matches!(location, Location::Remote(_)));
530 if has_remote_participants {
531 #[cfg(target_arch = "wasm32")]
532 {
533 let _ = role;
534 Err(TransportError::NotReady)
535 }
536 #[cfg(not(target_arch = "wasm32"))]
537 {
538 topology
539 .region_for_role(role)
540 .map_err(TransportError::ConnectionFailed)?;
541 let state = shared_tcp_role_state(topology, "transport_factory", role)?;
542 let warm_state = Arc::clone(&state);
543 spawn(async move {
544 let _ = warm_state.ensure_started().await;
545 });
546 let peer_endpoints = topology
547 .locations
548 .iter()
549 .filter(|(peer, _)| *peer != role)
550 .map(|(peer, location)| {
551 let _ = topology
552 .region_for_role(peer)
553 .map_err(TransportError::ConnectionFailed)?;
554 let endpoint = match location {
555 Location::Remote(endpoint) => Some(endpoint.clone()),
556 Location::Local | Location::Colocated(_) => None,
557 };
558 Ok((peer.clone(), endpoint))
559 })
560 .collect::<TransportResult<BTreeMap<_, _>>>()?;
561 Ok(Box::new(TcpRoleTransport {
562 state,
563 peer_endpoints,
564 }))
565 }
566 } else {
567 Ok(Box::new(InMemoryChannelTransport::new(role.clone())))
568 }
569 }
570
571 pub fn transport_for_location(
573 _from_role: &RoleName,
574 to_role: &RoleName,
575 topology: &Topology,
576 ) -> Result<TransportType, super::TopologyError> {
577 match topology.get_location(to_role)? {
578 Location::Local => Ok(TransportType::InMemory),
579 Location::Colocated(_) => Ok(TransportType::SharedMemory),
580 Location::Remote(_) => Ok(TransportType::Tcp),
581 }
582 }
583}
584
/// The kinds of transport a pair of roles can be connected by.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportType {
    /// In-process channels.
    InMemory,
    /// Shared memory between colocated roles.
    SharedMemory,
    /// TCP sockets.
    Tcp,
    /// WebSocket connections.
    WebSocket,
}

impl TransportType {
    /// Returns `true` for `InMemory` and `SharedMemory`, `false` otherwise.
    pub fn is_local(&self) -> bool {
        match self {
            TransportType::InMemory | TransportType::SharedMemory => true,
            TransportType::Tcp | TransportType::WebSocket => false,
        }
    }
}
604
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// Shorthand for building a role name from a literal.
    fn role(name: &'static str) -> RoleName {
        RoleName::from_static(name)
    }

    /// Shorthand for building an endpoint that is known to be valid.
    fn endpoint(address: &str) -> crate::identifiers::Endpoint {
        crate::identifiers::Endpoint::new(address).unwrap()
    }

    /// Two connected in-memory transports can exchange messages both ways.
    #[tokio::test]
    async fn test_in_memory_transport() {
        let alice = InMemoryChannelTransport::new(role("Alice"));
        let bob = InMemoryChannelTransport::new(role("Bob"));
        alice.connect(&bob).await;

        // Alice -> Bob
        alice
            .send(&role("Bob"), b"Hello Bob".to_vec())
            .await
            .unwrap();
        let received_by_bob = bob.recv(&role("Alice")).await.unwrap();
        assert_eq!(received_by_bob, b"Hello Bob".to_vec());

        // Bob -> Alice
        bob.send(&role("Alice"), b"Hello Alice".to_vec())
            .await
            .unwrap();
        let received_by_alice = alice.recv(&role("Bob")).await.unwrap();
        assert_eq!(received_by_alice, b"Hello Alice".to_vec());
    }

    /// The transport kind follows the location of the destination role.
    #[test]
    fn test_transport_type_for_location() {
        let topology = Topology::builder()
            .local_role(role("Alice"))
            .remote_role(role("Bob"), endpoint("localhost:8080"))
            .colocated_role(role("Carol"), role("Alice"))
            .build();

        let expectations = [
            ("Alice", TransportType::InMemory),
            ("Bob", TransportType::Tcp),
            ("Carol", TransportType::SharedMemory),
        ];
        for (destination, expected) in expectations {
            let actual = TransportFactory::transport_for_location(
                &role("Alice"),
                &role(destination),
                &topology,
            )
            .unwrap();
            assert_eq!(actual, expected);
        }
    }

    /// Only the in-memory and shared-memory kinds count as local.
    #[test]
    fn test_transport_type_is_local() {
        for local in [TransportType::InMemory, TransportType::SharedMemory] {
            assert!(local.is_local());
        }
        for remote in [TransportType::Tcp, TransportType::WebSocket] {
            assert!(!remote.is_local());
        }
    }

    /// The factory handles an all-local topology and a loopback remote one,
    /// and the remote transports deliver a message end to end.
    #[tokio::test]
    async fn test_transport_factory_create_supports_loopback_remote_topologies() {
        let local_topology = Topology::builder()
            .local_role(role("Alice"))
            .local_role(role("Bob"))
            .build();
        assert!(TransportFactory::create(&local_topology, &role("Alice")).is_ok());

        let remote_topology = Topology::builder()
            .remote_role(role("Alice"), endpoint("127.0.0.1:19801"))
            .remote_role(role("Bob"), endpoint("127.0.0.1:19802"))
            .build();
        let alice = TransportFactory::create(&remote_topology, &role("Alice"))
            .expect("remote transport for Alice");
        let bob = TransportFactory::create(&remote_topology, &role("Bob"))
            .expect("remote transport for Bob");

        alice
            .send(&role("Bob"), b"hello remote".to_vec())
            .await
            .expect("remote send");
        let delivered = bob.recv(&role("Alice")).await.expect("remote recv");
        assert_eq!(delivered, b"hello remote".to_vec());
    }
}