1use core::{
4 marker::PhantomData,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use std::{
9 collections::{HashMap, HashSet},
10 sync::{
11 Arc, Mutex,
12 atomic::{AtomicU64, Ordering},
13 },
14};
15
16use futures::{Future, SinkExt, Stream, StreamExt};
17pub use selium_switchboard_protocol::{AdoptMode, Backpressure, Cardinality, EndpointId};
18use selium_switchboard_protocol::{
19 Direction, EndpointDirections, Message, ProtocolError, WiringEgress, WiringIngress,
20 decode_message, encode_message,
21};
22use thiserror::Error;
23use tracing::{debug, warn};
24
25use selium_userland::{
26 Dependency, DependencyDescriptor, dependency_id,
27 encoding::{FlatMsg, HasSchema},
28 io::{Channel, ChannelHandle, DriverError, Reader, SharedChannel, Writer},
29};
30
31type PendingWiring<In, Out> =
32 Pin<Box<dyn Future<Output = Result<EndpointWiring<In, Out>, SwitchboardError>>>>;
33
34struct PendingUpdate<In, Out> {
35 future: PendingWiring<In, Out>,
36 reuse_inbound: bool,
37 reuse_outbound: bool,
38 next_inbound: Vec<WiringIngress>,
39 next_outbound: Vec<WiringEgress>,
40}
41
/// Chunk size passed to `subscribe` for control traffic (requests, responses, updates).
const CONTROL_CHUNK_SIZE: u32 = 64 * 1024;
/// Capacity passed to `Channel::create` for the shared updates channel.
const UPDATE_CHANNEL_CAPACITY: u32 = 64 * 1024;
/// Capacity passed to `Channel::create` for per-request response channels and
/// per-endpoint internal update channels.
const INTERNAL_CHANNEL_CAPACITY: u32 = 16 * 1024;
/// Chunk size passed to `subscribe` for inbound data links.
const DATA_CHUNK_SIZE: u32 = 64 * 1024;
46
47#[derive(Debug, Error)]
49pub enum SwitchboardError {
50 #[error("driver error: {0}")]
52 Driver(#[from] DriverError),
53 #[error("protocol error: {0}")]
55 Protocol(#[from] ProtocolError),
56 #[error("switchboard error: {0}")]
58 Remote(String),
59 #[error("endpoint closed")]
61 EndpointClosed,
62 #[error("no route available")]
64 NoRoute,
65 #[error("switchboard state unavailable")]
67 StateUnavailable,
68}
69
70#[derive(Clone)]
72pub struct Switchboard {
73 request_channel: Channel,
74 _updates_channel: Channel,
75 updates_shared: SharedChannel,
76 state: Arc<Mutex<ClientState>>,
77 next_request_id: Arc<AtomicU64>,
78}
79
80struct ClientState {
81 pending: HashMap<u64, Channel>,
82 endpoints: HashMap<EndpointId, Channel>,
83 queued_updates: HashMap<EndpointId, Vec<Vec<u8>>>,
84 ignored: HashSet<EndpointId>,
85}
86
87pub struct EndpointBuilder<In, Out> {
89 switchboard: Switchboard,
90 input: Cardinality,
91 output: Cardinality,
92 output_backpressure: Backpressure,
93 output_exclusive: bool,
94 _in: PhantomData<In>,
95 _out: PhantomData<Out>,
96}
97
98pub struct EndpointHandle<In, Out> {
100 id: EndpointId,
101 updates: Reader,
102 pending: Option<PendingUpdate<In, Out>>,
103 last_inbound: Vec<WiringIngress>,
104 last_outbound: Vec<WiringEgress>,
105 backpressure: Backpressure,
106 pub io: EndpointIo<In, Out>,
108}
109
110pub struct EndpointIo<In, Out> {
112 pub inbound: Vec<InboundLink<In>>,
114 pub outbound: Vec<RawPublisher<Out>>,
116 outbound_map: HashMap<EndpointId, ChannelHandle>,
117}
118
119pub struct InboundLink<In> {
121 pub from: EndpointId,
123 pub subscriber: RawSubscriber<In>,
125}
126
127struct EndpointWiring<In, Out> {
128 inbound: Vec<InboundLink<In>>,
129 outbound: Vec<RawPublisher<Out>>,
130 outbound_map: HashMap<EndpointId, ChannelHandle>,
131}
132
133pub struct RawPublisher<T> {
135 writer: Writer,
136 _marker: PhantomData<T>,
137}
138
139pub struct RawSubscriber<T> {
141 reader: Reader,
142 _marker: PhantomData<T>,
143}
144
145impl Switchboard {
146 pub async fn attach(request_channel: SharedChannel) -> Result<Self, SwitchboardError> {
148 let request_channel = Channel::attach_shared(request_channel).await?;
149 let updates_channel = Channel::create(UPDATE_CHANNEL_CAPACITY).await?;
150 let updates_shared = updates_channel.share().await?;
151 let updates_reader = updates_channel.subscribe(CONTROL_CHUNK_SIZE).await?;
152
153 let state = Arc::new(Mutex::new(ClientState {
154 pending: HashMap::new(),
155 endpoints: HashMap::new(),
156 queued_updates: HashMap::new(),
157 ignored: HashSet::new(),
158 }));
159
160 let dispatcher_state = Arc::clone(&state);
161 spawn_dispatcher(updates_reader, dispatcher_state);
162 debug!(
163 updates_channel = updates_shared.raw(),
164 "switchboard: dispatcher spawned"
165 );
166
167 Ok(Self {
168 request_channel,
169 _updates_channel: updates_channel,
170 updates_shared,
171 state,
172 next_request_id: Arc::new(AtomicU64::new(1)),
173 })
174 }
175
176 pub fn endpoint<In, Out>(&self) -> EndpointBuilder<In, Out>
178 where
179 In: FlatMsg + HasSchema + Send + Unpin + 'static,
180 Out: FlatMsg + HasSchema + Send + Unpin + 'static,
181 {
182 EndpointBuilder {
183 switchboard: self.clone(),
184 input: Cardinality::One,
185 output: Cardinality::One,
186 output_backpressure: Backpressure::Park,
187 output_exclusive: false,
188 _in: PhantomData,
189 _out: PhantomData,
190 }
191 }
192
193 pub async fn connect_ids(
195 &self,
196 from: EndpointId,
197 to: EndpointId,
198 ) -> Result<(), SwitchboardError> {
199 let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
200 let response = self
201 .send_request(Message::ConnectRequest {
202 request_id,
203 from,
204 to,
205 reply_channel: self.updates_shared.raw(),
206 })
207 .await?;
208
209 match response {
210 Message::ResponseOk { .. } => Ok(()),
211 Message::ResponseError { message, .. } => Err(SwitchboardError::Remote(message)),
212 _ => Err(SwitchboardError::Protocol(ProtocolError::UnknownPayload)),
213 }
214 }
215
216 pub async fn connect<In, Out, Common>(
218 &self,
219 from: &EndpointHandle<In, Common>,
220 to: &EndpointHandle<Common, Out>,
221 ) -> Result<(), SwitchboardError>
222 where
223 In: FlatMsg + Send + Unpin + 'static,
224 Out: FlatMsg + Send + Unpin + 'static,
225 Common: FlatMsg + Send + Unpin + 'static,
226 {
227 self.connect_ids(from.id, to.id).await
228 }
229
230 async fn register_endpoint(
231 &self,
232 directions: EndpointDirections,
233 ) -> Result<EndpointId, SwitchboardError> {
234 let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
235 debug!(request_id, "switchboard: registering endpoint");
236 let response = self
237 .send_request(Message::RegisterRequest {
238 request_id,
239 directions,
240 updates_channel: self.updates_shared.raw(),
241 })
242 .await?;
243
244 match response {
245 Message::ResponseRegister { endpoint_id, .. } => Ok(endpoint_id),
246 Message::ResponseError { message, .. } => Err(SwitchboardError::Remote(message)),
247 _ => Err(SwitchboardError::Protocol(ProtocolError::UnknownPayload)),
248 }
249 }
250
251 pub async fn adopt_output_channel<Out>(
253 &self,
254 channel: SharedChannel,
255 backpressure: Backpressure,
256 mode: AdoptMode,
257 ) -> Result<EndpointId, SwitchboardError>
258 where
259 Out: FlatMsg + HasSchema + Send + Unpin + 'static,
260 {
261 let directions = EndpointDirections::new(
262 Direction::new(Out::SCHEMA.hash, Cardinality::Zero, Backpressure::Park),
263 Direction::new(Out::SCHEMA.hash, Cardinality::One, backpressure).with_exclusive(true),
264 );
265 self.adopt_endpoint(directions, channel, mode).await
266 }
267
268 pub async fn adopt_endpoint(
270 &self,
271 directions: EndpointDirections,
272 channel: SharedChannel,
273 mode: AdoptMode,
274 ) -> Result<EndpointId, SwitchboardError> {
275 let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
276 debug!(request_id, "switchboard: adopting endpoint");
277 let response = self
278 .send_request(Message::AdoptRequest {
279 request_id,
280 directions,
281 updates_channel: self.updates_shared.raw(),
282 channel: channel.raw(),
283 mode,
284 })
285 .await?;
286
287 match response {
288 Message::ResponseRegister { endpoint_id, .. } => {
289 let mut guard = self.state()?;
290 guard.ignored.insert(endpoint_id);
291 Ok(endpoint_id)
292 }
293 Message::ResponseError { message, .. } => Err(SwitchboardError::Remote(message)),
294 _ => Err(SwitchboardError::Protocol(ProtocolError::UnknownPayload)),
295 }
296 }
297
298 async fn send_request(&self, message: Message) -> Result<Message, SwitchboardError> {
299 let request_id = request_id_for(&message);
300 debug!(request_id, "switchboard: sending request");
301 let response_channel = Channel::create(INTERNAL_CHANNEL_CAPACITY).await?;
302 let mut response_reader = response_channel.subscribe(CONTROL_CHUNK_SIZE).await?;
303
304 {
305 let mut guard = self.state()?;
306 guard.pending.insert(request_id, response_channel.clone());
307 }
308
309 let mut writer = self.request_channel.publish_weak().await?;
310 debug!(request_id, "switchboard: request channel opened");
311 let bytes = encode_message(&message)?;
312 writer.send(bytes).await?;
313 debug!(request_id, "switchboard: request sent");
314
315 let response = match response_reader.next().await {
316 Some(Ok(frame)) => decode_message(&frame.payload)?,
317 Some(Err(err)) => return Err(SwitchboardError::Driver(err)),
318 None => return Err(SwitchboardError::EndpointClosed),
319 };
320 debug!(request_id, "switchboard: received response");
321
322 {
323 let mut guard = self.state()?;
324 guard.pending.remove(&request_id);
325 }
326
327 response_channel.delete().await?;
328
329 Ok(response)
330 }
331
332 async fn register_endpoint_channel(
333 &self,
334 endpoint_id: EndpointId,
335 channel: Channel,
336 ) -> Result<(), SwitchboardError> {
337 let queued = {
338 let mut guard = self.state()?;
339 guard.ignored.remove(&endpoint_id);
340 guard.endpoints.insert(endpoint_id, channel.clone());
341 guard.queued_updates.remove(&endpoint_id)
342 };
343
344 if let Some(pending) = queued {
345 for payload in pending {
346 forward_bytes(channel.clone(), payload).await?;
347 }
348 }
349
350 Ok(())
351 }
352
353 fn state(&self) -> Result<std::sync::MutexGuard<'_, ClientState>, SwitchboardError> {
354 self.state
355 .lock()
356 .map_err(|_| SwitchboardError::StateUnavailable)
357 }
358}
359
360impl Dependency for Switchboard {
361 type Handle = SharedChannel;
362 type Error = SwitchboardError;
363
364 const DESCRIPTOR: DependencyDescriptor = DependencyDescriptor {
365 name: "selium::Switchboard",
366 id: dependency_id!("selium.switchboard.singleton"),
367 };
368
369 async fn from_handle(handle: Self::Handle) -> Result<Self, Self::Error> {
370 Switchboard::attach(handle).await
371 }
372}
373
374impl<In, Out> EndpointBuilder<In, Out>
375where
376 In: FlatMsg + HasSchema + Send + Unpin + 'static,
377 Out: FlatMsg + HasSchema + Send + Unpin + 'static,
378{
379 pub fn inputs(mut self, cardinality: Cardinality) -> Self {
381 self.input = cardinality;
382 self
383 }
384
385 pub fn outputs(mut self, cardinality: Cardinality) -> Self {
387 self.output = cardinality;
388 self
389 }
390
391 pub fn output_backpressure(mut self, backpressure: Backpressure) -> Self {
393 self.output_backpressure = backpressure;
394 self
395 }
396
397 pub fn output_exclusive(mut self, exclusive: bool) -> Self {
399 self.output_exclusive = exclusive;
400 self
401 }
402
403 pub async fn register(self) -> Result<EndpointHandle<In, Out>, SwitchboardError> {
405 let directions = EndpointDirections::new(
406 Direction::new(In::SCHEMA.hash, self.input, Backpressure::Park),
407 Direction::new(Out::SCHEMA.hash, self.output, self.output_backpressure)
408 .with_exclusive(self.output_exclusive),
409 );
410
411 let internal_channel = Channel::create(INTERNAL_CHANNEL_CAPACITY).await?;
412 let updates = internal_channel.subscribe(CONTROL_CHUNK_SIZE).await?;
413
414 let endpoint_id = self.switchboard.register_endpoint(directions).await?;
415 self.switchboard
416 .register_endpoint_channel(endpoint_id, internal_channel)
417 .await?;
418
419 Ok(EndpointHandle {
420 id: endpoint_id,
421 updates,
422 pending: None,
423 last_inbound: Vec::new(),
424 last_outbound: Vec::new(),
425 backpressure: self.output_backpressure,
426 io: EndpointIo::new(),
427 })
428 }
429}
430
431impl<In, Out> EndpointHandle<In, Out> {
432 pub fn id(&self) -> EndpointId {
434 self.id
435 }
436
437 pub fn get_id(&self) -> EndpointId {
439 self.id
440 }
441
442 pub fn outbound_handle(&self, target: EndpointId) -> Option<ChannelHandle> {
444 self.io.outbound_handle(target)
445 }
446
447 pub fn backpressure(&self) -> Backpressure {
449 self.backpressure
450 }
451}
452
453impl<In, Out> EndpointHandle<In, Out>
454where
455 In: FlatMsg + Send + Unpin + 'static,
456 Out: FlatMsg + Send + Unpin + 'static,
457{
458 pub(crate) fn poll_updates(&mut self, cx: &mut Context<'_>) -> Result<(), SwitchboardError> {
459 loop {
460 let mut completed = None;
461 if let Some(pending) = self.pending.as_mut() {
462 match pending.future.as_mut().poll(cx) {
463 Poll::Ready(Ok(wiring)) => completed = Some(Ok(wiring)),
464 Poll::Ready(Err(err)) => completed = Some(Err(err)),
465 Poll::Pending => {}
466 }
467 }
468
469 if let Some(result) = completed {
470 let pending = self.pending.take().expect("pending wiring");
471 match result {
472 Ok(mut wiring) => {
473 if pending.reuse_inbound {
474 wiring.inbound = std::mem::take(&mut self.io.inbound);
475 }
476 if pending.reuse_outbound {
477 wiring.outbound = std::mem::take(&mut self.io.outbound);
478 wiring.outbound_map = std::mem::take(&mut self.io.outbound_map);
479 }
480 self.io.apply_wiring(wiring);
481 self.last_inbound = pending.next_inbound;
482 self.last_outbound = pending.next_outbound;
483 }
484 Err(err) => return Err(err),
485 }
486 }
487
488 match Pin::new(&mut self.updates).poll_next(cx) {
489 Poll::Ready(Some(Ok(frame))) => {
490 if frame.payload.is_empty() {
491 continue;
492 }
493 let message = decode_message(&frame.payload)?;
494 if let Message::WiringUpdate {
495 inbound, outbound, ..
496 } = message
497 {
498 let reuse_inbound = inbound == self.last_inbound;
499 let reuse_outbound = outbound == self.last_outbound;
500 if reuse_inbound && reuse_outbound {
501 continue;
502 }
503
504 let next_inbound = inbound.clone();
505 let next_outbound = outbound.clone();
506
507 let future = Box::pin(async move {
508 let inbound_links = if reuse_inbound {
509 Vec::new()
510 } else {
511 build_inbound(inbound).await?
512 };
513 let (outbound_links, outbound_map) = if reuse_outbound {
514 (Vec::new(), HashMap::new())
515 } else {
516 build_outbound(outbound).await?
517 };
518
519 Ok(EndpointWiring {
520 inbound: inbound_links,
521 outbound: outbound_links,
522 outbound_map,
523 })
524 });
525
526 self.pending = Some(PendingUpdate {
527 future,
528 reuse_inbound,
529 reuse_outbound,
530 next_inbound,
531 next_outbound,
532 });
533 }
534 }
535 Poll::Ready(Some(Err(err))) => return Err(SwitchboardError::Driver(err)),
536 Poll::Ready(None) => return Err(SwitchboardError::EndpointClosed),
537 Poll::Pending => break,
538 }
539 }
540
541 Ok(())
542 }
543}
544
545impl<In, Out> EndpointIo<In, Out> {
546 pub fn outbound_handle(&self, target: EndpointId) -> Option<ChannelHandle> {
548 self.outbound_map.get(&target).cloned()
549 }
550}
551
552impl<In, Out> EndpointIo<In, Out>
553where
554 In: FlatMsg + Send + Unpin + 'static,
555 Out: FlatMsg + Send + Unpin + 'static,
556{
557 fn new() -> Self {
558 Self {
559 inbound: Vec::new(),
560 outbound: Vec::new(),
561 outbound_map: HashMap::new(),
562 }
563 }
564
565 fn apply_wiring(&mut self, wiring: EndpointWiring<In, Out>) {
566 self.inbound = wiring.inbound;
567 self.outbound = wiring.outbound;
568 self.outbound_map = wiring.outbound_map;
569 }
570}
571
572impl<T> RawPublisher<T> {
573 fn new(writer: Writer) -> Self {
574 Self {
575 writer,
576 _marker: PhantomData,
577 }
578 }
579
580 pub(crate) async fn from_channel_handle(
581 handle: ChannelHandle,
582 ) -> Result<Self, SwitchboardError> {
583 let channel = unsafe { Channel::from_raw(handle) };
584 let writer = channel.publish_weak().await?;
585 Ok(Self::new(writer))
586 }
587}
588
589impl<T> RawSubscriber<T> {
590 fn new(reader: Reader) -> Self {
591 Self {
592 reader,
593 _marker: PhantomData,
594 }
595 }
596}
597
598impl<T> futures::Sink<T> for RawPublisher<T>
599where
600 T: FlatMsg + Send + Unpin + 'static,
601{
602 type Error = SwitchboardError;
603
604 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
605 match Pin::new(&mut self.get_mut().writer).poll_ready(cx) {
606 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
607 Poll::Ready(Err(err)) => Poll::Ready(Err(SwitchboardError::Driver(err))),
608 Poll::Pending => Poll::Pending,
609 }
610 }
611
612 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
613 let encoded = FlatMsg::encode(&item);
614 Pin::new(&mut self.get_mut().writer)
615 .start_send(encoded)
616 .map_err(SwitchboardError::Driver)
617 }
618
619 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
620 match Pin::new(&mut self.get_mut().writer).poll_flush(cx) {
621 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
622 Poll::Ready(Err(err)) => Poll::Ready(Err(SwitchboardError::Driver(err))),
623 Poll::Pending => Poll::Pending,
624 }
625 }
626
627 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
628 match Pin::new(&mut self.get_mut().writer).poll_close(cx) {
629 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
630 Poll::Ready(Err(err)) => Poll::Ready(Err(SwitchboardError::Driver(err))),
631 Poll::Pending => Poll::Pending,
632 }
633 }
634}
635
636impl<T> futures::Stream for RawSubscriber<T>
637where
638 T: FlatMsg + Send + Unpin + 'static,
639{
640 type Item = Result<T, SwitchboardError>;
641
642 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
643 match Pin::new(&mut self.get_mut().reader).poll_next(cx) {
644 Poll::Ready(Some(Ok(frame))) => {
645 let decoded =
646 T::decode(&frame.payload).map_err(|err| SwitchboardError::Protocol(err.into()));
647 Poll::Ready(Some(decoded))
648 }
649 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(SwitchboardError::Driver(err)))),
650 Poll::Ready(None) => Poll::Ready(None),
651 Poll::Pending => Poll::Pending,
652 }
653 }
654}
655
656async fn build_inbound<In>(
657 inbound: Vec<WiringIngress>,
658) -> Result<Vec<InboundLink<In>>, SwitchboardError>
659where
660 In: FlatMsg + Send + Unpin + 'static,
661{
662 let mut inbound_links = Vec::with_capacity(inbound.len());
663 for ingress in inbound {
664 let shared = unsafe { SharedChannel::from_raw(ingress.channel) };
665 let channel = Channel::attach_shared(shared).await?;
666 let reader = channel.subscribe(DATA_CHUNK_SIZE).await?;
667 inbound_links.push(InboundLink {
668 from: ingress.from,
669 subscriber: RawSubscriber::new(reader),
670 });
671 }
672 Ok(inbound_links)
673}
674
675async fn build_outbound<Out>(
676 outbound: Vec<WiringEgress>,
677) -> Result<(Vec<RawPublisher<Out>>, HashMap<EndpointId, ChannelHandle>), SwitchboardError>
678where
679 Out: FlatMsg + Send + Unpin + 'static,
680{
681 let mut outbound_links = Vec::with_capacity(outbound.len());
682 let mut outbound_map = HashMap::with_capacity(outbound.len());
683 for egress in outbound {
684 let shared = unsafe { SharedChannel::from_raw(egress.channel) };
685 let channel = Channel::attach_shared(shared).await?;
686 let writer = channel.publish_weak().await?;
687 outbound_map.insert(egress.to, channel.handle());
688 outbound_links.push(RawPublisher::new(writer));
689 }
690 Ok((outbound_links, outbound_map))
691}
692
693fn request_id_for(message: &Message) -> u64 {
694 match message {
695 Message::RegisterRequest { request_id, .. }
696 | Message::AdoptRequest { request_id, .. }
697 | Message::ConnectRequest { request_id, .. }
698 | Message::ResponseRegister { request_id, .. }
699 | Message::ResponseOk { request_id, .. }
700 | Message::ResponseError { request_id, .. } => *request_id,
701 Message::WiringUpdate { .. } => 0,
702 }
703}
704
705fn spawn_dispatcher(reader: Reader, state: Arc<Mutex<ClientState>>) {
706 #[cfg(test)]
707 {
708 tokio::spawn(async move {
709 drive_updates(reader, state).await;
710 });
711 }
712 #[cfg(not(test))]
713 {
714 selium_userland::spawn(async move {
715 drive_updates(reader, state).await;
716 });
717 }
718}
719
720async fn drive_updates(mut reader: Reader, state: Arc<Mutex<ClientState>>) {
721 while let Some(frame) = reader.next().await {
722 let payload = match frame {
723 Ok(frame) if frame.payload.is_empty() => continue,
724 Ok(frame) => frame.payload,
725 Err(err) => {
726 warn!(?err, "switchboard: update stream failed");
727 break;
728 }
729 };
730
731 let message = match decode_message(&payload) {
732 Ok(message) => message,
733 Err(err) => {
734 warn!(?err, "switchboard: failed to decode update");
735 continue;
736 }
737 };
738
739 match message {
740 Message::ResponseOk { request_id }
741 | Message::ResponseRegister { request_id, .. }
742 | Message::ResponseError { request_id, .. } => {
743 debug!(request_id, "switchboard: dispatching response");
744 let channel = match state.lock() {
745 Ok(mut guard) => guard.pending.remove(&request_id),
746 Err(_) => None,
747 };
748 if let Some(channel) = channel
749 && let Err(err) = forward_bytes(channel, payload).await
750 {
751 warn!(?err, "switchboard: failed to deliver response");
752 }
753 }
754 Message::WiringUpdate { endpoint_id, .. } => {
755 debug!(endpoint_id, "switchboard: dispatching wiring update");
756 let (channel, queued, ignored) = match state.lock() {
757 Ok(mut guard) => {
758 if guard.ignored.contains(&endpoint_id) {
759 (None, None, true)
760 } else if let Some(channel) = guard.endpoints.get(&endpoint_id).cloned() {
761 (Some(channel), None, false)
762 } else {
763 let entry = guard.queued_updates.entry(endpoint_id).or_default();
764 entry.push(payload.clone());
765 (None, Some(()), false)
766 }
767 }
768 Err(_) => (None, Some(()), false),
769 };
770
771 if ignored {
772 continue;
773 }
774
775 if let Some(channel) = channel
776 && let Err(err) = forward_bytes(channel, payload).await
777 {
778 warn!(?err, "switchboard: failed to deliver update");
779 }
780
781 if queued.is_some() {
782 continue;
783 }
784 }
785 _ => {}
786 }
787 }
788}
789
790async fn forward_bytes(channel: Channel, payload: Vec<u8>) -> Result<(), DriverError> {
791 let mut writer = channel.publish_weak().await?;
792 writer.send(payload).await?;
793 Ok(())
794}
795
#[cfg(test)]
impl Switchboard {
    /// Test helper: starts an in-process switchboard service on a fresh request channel
    /// and attaches a client to it.
    pub async fn new() -> Result<Self, SwitchboardError> {
        let service_channel = Channel::create(UPDATE_CHANNEL_CAPACITY).await?;
        let service_shared = service_channel.share().await?;
        spawn_local_service(service_shared);
        Self::attach(service_shared).await
    }
}
806
/// Test helper: runs the in-process switchboard service on a tokio task and logs the
/// error if the service terminates with one.
#[cfg(test)]
fn spawn_local_service(request_channel: SharedChannel) {
    tokio::spawn(async move {
        match run_local_service(request_channel).await {
            Ok(()) => {}
            Err(err) => warn!(?err, "switchboard: local service terminated"),
        }
    });
}
815
/// Test helper: an in-process switchboard service.
///
/// Reads requests from `request_channel`, keeps endpoints and connection intents in a
/// `SwitchboardCore`, and after every accepted change pushes the solved wiring to all
/// registered endpoints. Only register and connect requests are handled; every other
/// message is ignored.
///
/// # Errors
/// Returns when the request stream yields an error, a request cannot be decoded, or a
/// response cannot be delivered. Returns `Ok(())` when the request stream ends.
#[cfg(test)]
async fn run_local_service(request_channel: SharedChannel) -> Result<(), SwitchboardError> {
    use selium_switchboard_core::{
        ChannelKey, SwitchboardCore, SwitchboardError as CoreError, best_compatible_match,
    };

    let request_channel = Channel::attach_shared(request_channel).await?;
    let mut reader = request_channel.subscribe(CONTROL_CHUNK_SIZE).await?;

    /// A data channel owned by the service, tagged with the key it currently serves.
    struct LocalChannelEntry {
        channel: Channel,
        shared: SharedChannel,
        key: ChannelKey,
    }

    /// Wiring computed for one endpoint.
    #[derive(Default)]
    struct EndpointUpdate {
        inbound: Vec<WiringIngress>,
        outbound: Vec<WiringEgress>,
    }

    /// Service state: the solver, the update writers of the registered endpoints and
    /// the data channels currently in use.
    struct LocalService {
        core: SwitchboardCore,
        endpoints: HashMap<EndpointId, Writer>,
        channels: Vec<LocalChannelEntry>,
    }

    impl LocalService {
        /// Creates a service without endpoints or channels.
        fn new() -> Self {
            Self {
                core: SwitchboardCore::default(),
                endpoints: HashMap::new(),
                channels: Vec::new(),
            }
        }

        /// Dispatches a decoded request to its handler; unsupported messages are ignored.
        async fn handle_message(&mut self, message: Message) -> Result<(), SwitchboardError> {
            match message {
                Message::RegisterRequest {
                    request_id,
                    directions,
                    updates_channel,
                } => {
                    self.handle_register(request_id, directions, updates_channel)
                        .await
                }
                Message::ConnectRequest {
                    request_id,
                    from,
                    to,
                    reply_channel,
                } => {
                    self.handle_connect(request_id, from, to, reply_channel)
                        .await
                }
                _ => Ok(()),
            }
        }

        /// Adds an endpoint, reconciles the wiring and answers with the new id.
        ///
        /// If reconciliation fails the endpoint is removed again and an error response
        /// is sent instead.
        async fn handle_register(
            &mut self,
            request_id: u64,
            directions: EndpointDirections,
            updates_channel: u64,
        ) -> Result<(), SwitchboardError> {
            // SAFETY: NOTE(review): relies on the raw id from the request satisfying the
            // contract of `SharedChannel::from_raw`, which is not visible here — confirm.
            let updates_channel = unsafe { SharedChannel::from_raw(updates_channel) };
            let updates_writer = Channel::attach_shared(updates_channel)
                .await?
                .publish_weak()
                .await?;

            let endpoint_id = self.core.add_endpoint(directions);
            self.endpoints.insert(endpoint_id, updates_writer);

            if let Err(err) = self.reconcile().await {
                self.core.remove_endpoint(endpoint_id);
                self.endpoints.remove(&endpoint_id);
                self.send_error(updates_channel, request_id, err).await?;
                return Ok(());
            }

            let response = Message::ResponseRegister {
                request_id,
                endpoint_id,
            };
            self.send_response(updates_channel, response).await?;
            Ok(())
        }

        /// Records a connection intent, reconciles the wiring and answers the request.
        ///
        /// If the intent is rejected, or reconciliation fails (in which case the intent
        /// is removed again), an error response is sent instead.
        async fn handle_connect(
            &mut self,
            request_id: u64,
            from: EndpointId,
            to: EndpointId,
            reply_channel: u64,
        ) -> Result<(), SwitchboardError> {
            // SAFETY: NOTE(review): relies on the raw id from the request satisfying the
            // contract of `SharedChannel::from_raw`, which is not visible here — confirm.
            let reply_channel = unsafe { SharedChannel::from_raw(reply_channel) };
            if let Err(err) = self.core.add_intent(from, to) {
                self.send_error(reply_channel, request_id, err.into())
                    .await?;
                return Ok(());
            }

            if let Err(err) = self.reconcile().await {
                self.core.remove_intent(from, to);
                self.send_error(reply_channel, request_id, err).await?;
                return Ok(());
            }

            let response = Message::ResponseOk { request_id };
            self.send_response(reply_channel, response).await?;
            Ok(())
        }

        /// Encodes `message` and publishes it on `channel`.
        async fn send_response(
            &self,
            channel: SharedChannel,
            message: Message,
        ) -> Result<(), SwitchboardError> {
            let channel = Channel::attach_shared(channel).await?;
            let mut writer = channel.publish_weak().await?;
            let bytes = encode_message(&message)?;
            writer.send(bytes).await?;
            Ok(())
        }

        /// Publishes an error response whose message is the display form of `err`.
        async fn send_error(
            &self,
            channel: SharedChannel,
            request_id: u64,
            err: SwitchboardError,
        ) -> Result<(), SwitchboardError> {
            let response = Message::ResponseError {
                request_id,
                message: err.to_string(),
            };
            self.send_response(channel, response).await
        }

        /// Solves the current intents, realises the solution and notifies every endpoint.
        async fn reconcile(&mut self) -> Result<(), SwitchboardError> {
            let solution = self.core.solve().map_err(SwitchboardError::from)?;
            let wiring = self.apply_solution(solution).await?;
            self.send_updates(wiring).await?;
            Ok(())
        }

        /// Realises `solution` with data channels and derives the wiring per endpoint.
        ///
        /// Existing channels are reused when their key matches exactly or compatibly;
        /// missing ones are created and unused ones are drained and deleted. If a
        /// channel cannot be created, the channels known so far stay tracked by the
        /// service and the error is returned.
        async fn apply_solution(
            &mut self,
            solution: selium_switchboard_core::Solution,
        ) -> Result<HashMap<EndpointId, EndpointUpdate>, SwitchboardError> {
            let mut available = std::mem::take(&mut self.channels);
            let mut retained = Vec::with_capacity(solution.channels.len());
            let mut resolved: Vec<SharedChannel> = Vec::with_capacity(solution.channels.len());

            for spec in &solution.channels {
                let desired_key = spec.key().clone();
                let position = available
                    .iter()
                    .position(|entry| entry.key == desired_key)
                    .or_else(|| {
                        let keys: Vec<ChannelKey> =
                            available.iter().map(|entry| entry.key.clone()).collect();
                        best_compatible_match(&keys, &desired_key)
                    });

                let entry = match position {
                    Some(pos) => {
                        let mut entry = available.swap_remove(pos);
                        // Exact matches already carry this key; compatible matches are
                        // re-labelled with the key they now serve.
                        entry.key = desired_key;
                        entry
                    }
                    None => match create_local_channel(desired_key).await {
                        Ok(entry) => entry,
                        Err(err) => {
                            // Keep tracking every channel that still exists instead of
                            // dropping the bookkeeping along with the failed attempt.
                            retained.append(&mut available);
                            self.channels = retained;
                            return Err(err);
                        }
                    },
                };
                resolved.push(entry.shared);
                retained.push(entry);
            }

            for entry in available {
                if let Err(err) = entry.channel.drain().await {
                    warn!(?err, "switchboard: failed to drain channel");
                }
                if let Err(err) = entry.channel.delete().await {
                    warn!(?err, "switchboard: failed to delete channel");
                }
            }

            self.channels = retained;

            let mut wiring: HashMap<EndpointId, EndpointUpdate> = HashMap::new();
            for route in &solution.routes {
                for flow in &route.flows {
                    // Built lazily: the error value is only needed for a bad index.
                    let handle = resolved
                        .get(flow.channel)
                        .ok_or_else(|| SwitchboardError::from(CoreError::Unsolveable))?;
                    wiring
                        .entry(flow.producer)
                        .or_default()
                        .outbound
                        .push(WiringEgress {
                            to: flow.consumer,
                            channel: handle.raw(),
                        });
                    wiring
                        .entry(flow.consumer)
                        .or_default()
                        .inbound
                        .push(WiringIngress {
                            from: flow.producer,
                            channel: handle.raw(),
                        });
                }
            }

            // Sort for a stable order and keep a single entry per channel.
            for entry in wiring.values_mut() {
                entry
                    .inbound
                    .sort_unstable_by_key(|ingress| (ingress.channel, ingress.from));
                entry.inbound.dedup_by_key(|ingress| ingress.channel);
                entry
                    .outbound
                    .sort_unstable_by_key(|egress| (egress.channel, egress.to));
                entry.outbound.dedup_by_key(|egress| egress.channel);
            }

            Ok(wiring)
        }

        /// Sends every registered endpoint its wiring; endpoints absent from `wiring`
        /// receive an empty update. Delivery failures are logged and skipped.
        async fn send_updates(
            &mut self,
            mut wiring: HashMap<EndpointId, EndpointUpdate>,
        ) -> Result<(), SwitchboardError> {
            for (endpoint_id, writer) in &mut self.endpoints {
                let update = wiring.remove(endpoint_id).unwrap_or_default();
                let message = Message::WiringUpdate {
                    endpoint_id: *endpoint_id,
                    inbound: update.inbound,
                    outbound: update.outbound,
                };
                let bytes = encode_message(&message)?;
                if let Err(err) = writer.send(bytes).await {
                    warn!(?err, endpoint_id, "switchboard: failed to send update");
                }
            }
            Ok(())
        }
    }

    /// Creates and shares a new data channel serving `key`.
    async fn create_local_channel(key: ChannelKey) -> Result<LocalChannelEntry, SwitchboardError> {
        let channel = Channel::create(UPDATE_CHANNEL_CAPACITY).await?;
        let shared = channel.share().await?;
        Ok(LocalChannelEntry {
            channel,
            shared,
            key,
        })
    }

    let mut service = LocalService::new();

    while let Some(frame) = reader.next().await {
        match frame {
            Ok(frame) => {
                if frame.payload.is_empty() {
                    continue;
                }
                let message = decode_message(&frame.payload)?;
                service.handle_message(message).await?;
            }
            Err(err) => return Err(SwitchboardError::Driver(err)),
        }
    }

    Ok(())
}
1092
/// Test-only conversion: a core solver error is reported as a remote error carrying
/// its display text.
#[cfg(test)]
impl From<selium_switchboard_core::SwitchboardError> for SwitchboardError {
    fn from(value: selium_switchboard_core::SwitchboardError) -> Self {
        let message = value.to_string();
        Self::Remote(message)
    }
}